VisionBias-in-AI / pages /4_Gender_Bias_Occupation.py
ratneshpasi03's picture
Update pages/4_Gender_Bias_Occupation.py
115b8fb verified
import streamlit as st
from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel
from PIL import Image
import requests
import os
import pandas as pd
import matplotlib.pyplot as plt
from io import BytesIO
# --- Caching models ---
@st.cache_resource
def load_blip_model():
    """Load the BLIP image-captioning processor and model.

    Decorated with ``st.cache_resource`` so the objects are created once and
    reused on later reruns instead of being re-instantiated every time.

    Returns:
        tuple: ``(BlipProcessor, BlipForConditionalGeneration)`` built from the
        ``Salesforce/blip-image-captioning-base`` checkpoint.
    """
    checkpoint = "Salesforce/blip-image-captioning-base"
    blip_processor = BlipProcessor.from_pretrained(checkpoint)
    blip_model = BlipForConditionalGeneration.from_pretrained(checkpoint)
    return blip_processor, blip_model
@st.cache_resource
def load_clip_model():
    """Load the CLIP processor and model used for zero-shot classification.

    Decorated with ``st.cache_resource`` so the objects are created once and
    reused on later reruns instead of being re-instantiated every time.

    Returns:
        tuple: ``(CLIPProcessor, CLIPModel)`` built from the
        ``openai/clip-vit-base-patch32`` checkpoint.
    """
    checkpoint = "openai/clip-vit-base-patch32"
    clip_processor = CLIPProcessor.from_pretrained(checkpoint)
    clip_model = CLIPModel.from_pretrained(checkpoint)
    return clip_processor, clip_model
# --- Analysis Functions ---
def analyze_with_blip(image):
    """Generate a BLIP caption for ``image`` and render it on the page.

    Args:
        image: PIL image to caption (the page converts every input source to
            RGB before calling this).

    Side effects:
        Writes a subheader, shows a spinner during generation, then writes a
        success message and the decoded caption. Returns ``None``.
    """
    st.subheader("BLIP Caption Generation")
    blip_processor, blip_model = load_blip_model()
    with st.spinner("Generating caption..."):
        model_inputs = blip_processor(image, return_tensors="pt")
        generated_ids = blip_model.generate(**model_inputs)
        # A single image was encoded, so the first generated sequence is its caption.
        caption = blip_processor.decode(generated_ids[0], skip_special_tokens=True)
    st.success("Caption generated!")
    st.write(f"**Caption:** {caption}")
def analyze_with_clip(image):
    """Zero-shot classify ``image`` against fixed occupation prompts with CLIP.

    The image is scored against each text prompt. The image-to-text logits are
    turned into a probability distribution over the prompts with a softmax,
    and the result is rendered as a table, the top label and a bar chart.

    Args:
        image: PIL image to classify (the page converts every input source to
            RGB before calling this).

    Side effects:
        Writes a subheader, spinner, success message, probability table,
        most-likely label and a matplotlib bar chart. Returns ``None``.
    """
    st.subheader("CLIP Occupation Classification")
    processor, model = load_clip_model()
    texts = ["a nurse", "a doctor", "a scientist", "an engineer"]
    with st.spinner("Classifying image..."):
        inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
        outputs = model(**inputs)
        # Softmax over dim=1 normalizes across the text prompts; row 0 is the
        # single input image. detach() drops the autograd graph so the tensor
        # can be converted to NumPy.
        probs = outputs.logits_per_image.softmax(dim=1).detach().numpy()[0]
    st.success("Classification done!")

    # Show table (1-based row numbers read more naturally in the UI)
    df = pd.DataFrame({
        "Label": texts,
        "Probability": probs,
    })
    df.index += 1
    st.table(df)

    # Most likely label
    st.write(f"**Most likely label:** `{texts[probs.argmax()]}`")

    # Bar chart
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.barh(texts, probs, color="lightcoral")
    ax.set_xlabel("Probability")
    ax.set_xlim(0, 1)
    ax.set_title("Occupation Classification Probabilities")
    st.pyplot(fig)
    # plt.subplots registers the figure with pyplot's global figure manager.
    # Without an explicit close, every analysis (and every rerun) leaks a figure.
    plt.close(fig)
# --- Main Content (runs top-to-bottom on every Streamlit rerun; NO FUNCTION WRAPPING) ---
st.title("Gender Bias in Occupation Detection")
st.markdown(
    """
This demo highlights potential biases in occupation-related image tagging.
For example, a woman in a lab might be labeled as a **nurse**, while a man might be labeled as a **scientist**.
Select an image and model to explore!
"""
)
st.divider()

# --- Model and input options (rendered in the main page area) ---
model_choice = st.selectbox("Select Model", ["CLIP", "BLIP"])
input_method = st.selectbox("Select Input Method", ["Default Images", "Upload Image", "Use Image URL"])

# Seconds to wait when fetching a remote image before giving up.
URL_FETCH_TIMEOUT_S = 10

# RGB PIL image to analyze; stays None until an input method yields one.
image = None

# --- Image Input ---
if input_method == "Upload Image":
    uploaded_file = st.file_uploader("Upload your own image", type=["jpg", "jpeg", "png"])
    if uploaded_file:
        try:
            image = Image.open(uploaded_file).convert("RGB")
        except Exception:
            # A corrupt or mislabelled upload should produce a message, not a traceback.
            st.error("Couldn't read the uploaded file as an image.")
elif input_method == "Use Image URL":
    image_url = st.text_input("Paste an image URL here")
    if image_url:
        try:
            # Without a timeout, requests can block forever on an unresponsive
            # host and freeze this session's script run.
            response = requests.get(image_url, timeout=URL_FETCH_TIMEOUT_S)
            # Fail on 4xx/5xx instead of trying to decode an HTML error page.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception:
            # UI boundary: any network, HTTP or decoding failure becomes a message.
            st.error("Couldn't load image from the provided URL.")
elif input_method == "Default Images":
    default_folder = "default_images/occupation_bias"
    if os.path.isdir(default_folder):
        default_images = sorted(
            f for f in os.listdir(default_folder)
            if f.lower().endswith((".jpg", ".jpeg", ".png"))
        )
        if default_images:
            st.subheader("Choose a default image:")
            cols = st.columns(3)  # 3 images per row
            # Button clicks trigger a rerun, so the chosen file name must live in session state.
            if "selected_image" not in st.session_state:
                st.session_state.selected_image = None
            for idx, img_file in enumerate(default_images):
                img_path = os.path.join(default_folder, img_file)
                with cols[idx % 3]:
                    # Context manager releases the file handle once the thumbnail is rendered.
                    with Image.open(img_path) as img:
                        st.image(img, caption=img_file, use_container_width=True)
                    if st.button(f"Select {chr(65+idx)}", key=f"select_button_{idx}"):
                        st.session_state.selected_image = img_file
            selected = st.session_state.selected_image
            # Ignore a stale selection whose file is no longer in the folder.
            if selected in default_images:
                with Image.open(os.path.join(default_folder, selected)) as img:
                    image = img.convert("RGB")
                st.success(f"Selected image: {selected}")
        else:
            st.warning("No default images found in the occupation_bias folder.")
    else:
        st.warning(f"Folder '{default_folder}' does not exist.")
st.divider()

# --- Display Image and Analyze Button ---
if image is not None:
    st.image(image, caption="Input Image", use_container_width=True)
    if st.button("Analyze Image"):
        if model_choice == "BLIP":
            analyze_with_blip(image)
        elif model_choice == "CLIP":
            analyze_with_clip(image)
else:
    st.info("Upload or select an image to get started.")