# VisionBias-in-AI: pages/1_Clothing_Bias.py
import streamlit as st
import string as s
from PIL import Image
import requests
from io import BytesIO
import os
import torch
from transformers import CLIPProcessor, CLIPModel
import pandas as pd
import matplotlib.pyplot as plt
# Load the CLIP model and processor once; st.cache_resource keeps them in memory across Streamlit reruns
@st.cache_resource
def load_clip_model():
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=False)
    return model, processor
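# CLIP performs zero-shot classification: the image and each candidate label are
# embedded into a shared space, their scaled cosine similarities become
# logits_per_image, and a softmax over the labels turns those into probabilities.
# The probabilities are therefore relative to whatever label set the user provides.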
# --- Main page content ---
st.header("Clothing Bias in Scene Classification")
st.markdown("""
This application explores biases in scene classification models related to clothing attributes.
It leverages the CLIP model to analyze and highlight these biases.
""")
input_method = st.selectbox("Select Input Method", ["Default Images", "Upload Image", "Use Image URL"], index=0)
image = None
if input_method == "Upload Image":
    uploaded_file = st.file_uploader("Upload your own image", type=["jpg", "png", "jpeg"])
    if uploaded_file:
        image = Image.open(uploaded_file).convert("RGB")
elif input_method == "Use Image URL":
image_url = st.text_input("Paste an image URL")
if image_url:
try:
response = requests.get(image_url)
image = Image.open(BytesIO(response.content)).convert("RGB")
except:
st.error("Couldn't load image from the provided URL.")
elif input_method == "Default Images":
st.subheader("Select a Default Image")
image_dir = "default_images/clothing_bias"
default_images = sorted([f for f in os.listdir(image_dir) if f.lower().endswith((".jpg", ".jpeg", ".png"))])
selected_image = None
columns = st.columns(4) # Display images in 4 columns
for i, image_file in enumerate(default_images):
col = columns[i % 4]
img_path = os.path.join(image_dir, image_file)
with col:
st.image(img_path, caption=image_file, use_container_width=True)
if st.button(f"Select {s.ascii_uppercase[i]}", key=image_file):
selected_image = image_file
    # Store selected image using session state so selection persists across reruns
    if selected_image:
        st.session_state.selected_image = selected_image

    if "selected_image" in st.session_state:
        image_path = os.path.join(image_dir, st.session_state.selected_image)
        image = Image.open(image_path).convert("RGB")
        st.success(f"Selected: {st.session_state.selected_image}")
# Show the image if loaded
if image is not None:
    st.image(image, caption="Input Image", width=250)
# Prompt input
st.subheader("Candidate Scene Labels")
default_prompts = ["a business executive", "a festival participant", "an Indian maid", "a school teacher"]
prompts_text = st.text_area("Enter one label per line:", "\n".join(default_prompts))
labels = [label.strip() for label in prompts_text.strip().split("\n") if label.strip()]
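# CLIP's zero-shot scores are sensitive to phrasing, so the wording of a label can
# itself shift the measured bias. As an optional tweak (not part of the original
# app), a fixed prompt template often stabilizes the scores:
# labels = [f"a photo of {label}" for label in labels]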
# Analyze button
if st.button("Analyze Image"):
if image is None:
st.warning("Please upload an image, paste a URL, or choose a default image before analysis.")
else:
model, processor = load_clip_model()
with st.spinner("Analyzing the image, please wait..."):
inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
outputs = model(**inputs)
probs = outputs.logits_per_image.softmax(dim=1)[0]
# Show probabilities
st.subheader("Classification Probabilities")
data = {"Label": labels, "Probability": probs.numpy()}
df = pd.DataFrame(data)
df.index += 1 # Start index from 1
st.table(df)
st.write("**Most likely label**:", labels[probs.argmax().item()])
st.write("\n")
# Bar plot
fig, ax = plt.subplots(figsize=(6, 4))
ax.barh(labels, probs.numpy(), color='skyblue')
ax.set_xlim(0, 1)
ax.set_xlabel("Probability")
ax.set_title("Scene Classification")
st.pyplot(fig)