import os
from io import BytesIO

import matplotlib.pyplot as plt
import pandas as pd
import requests
import streamlit as st
import torch
from PIL import Image
from transformers import (
    BlipForConditionalGeneration,
    BlipProcessor,
    CLIPModel,
    CLIPProcessor,
)


# --- Caching models ---
@st.cache_resource
def load_blip_model():
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
    return processor, model


@st.cache_resource
def load_clip_model():
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    return processor, model


# --- Analysis Functions ---
def analyze_with_blip(image):
    st.subheader("BLIP Caption Generation")
    processor, model = load_blip_model()
    with st.spinner("Generating caption..."):
        inputs = processor(image, return_tensors="pt")
        with torch.no_grad():  # inference only; no gradients needed
            out = model.generate(**inputs)
        caption = processor.decode(out[0], skip_special_tokens=True)
    st.success("Caption generated!")
    st.write(f"**Caption:** {caption}")


def analyze_with_clip(image):
    st.subheader("CLIP Occupation Classification")
    processor, model = load_clip_model()
    texts = ["a nurse", "a doctor", "a scientist", "an engineer"]
    with st.spinner("Classifying image..."):
        inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
        with torch.no_grad():  # inference only; no gradients needed
            outputs = model(**inputs)
        probs = outputs.logits_per_image.softmax(dim=1).numpy()[0]
    st.success("Classification done!")

    # Show table
    df = pd.DataFrame({"Label": texts, "Probability": probs})
    df.index += 1
    st.table(df)

    # Most likely label
    st.write(f"**Most likely label:** `{texts[probs.argmax()]}`")

    # Bar chart
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.barh(texts, probs, color="lightcoral")
    ax.set_xlabel("Probability")
    ax.set_xlim(0, 1)
    ax.set_title("Occupation Classification Probabilities")
    st.pyplot(fig)


# --- Main Content (NO FUNCTION WRAPPING) ---
st.title("Gender Bias in Occupation Detection")
st.markdown(
    """
This demo highlights potential biases in occupation-related image tagging.
For example, a woman in a lab might be labeled as a **nurse**, while a man
might be labeled as a **scientist**. Select an image and a model to explore!
"""
)
st.divider()

# --- Model and Input Options ---
model_choice = st.selectbox("Select Model", ["CLIP", "BLIP"])
input_method = st.selectbox("Select Input Method", ["Default Images", "Upload Image", "Use Image URL"])

image = None

# --- Image Input ---
if input_method == "Upload Image":
    uploaded_file = st.file_uploader("Upload your own image", type=["jpg", "jpeg", "png"])
    if uploaded_file:
        image = Image.open(uploaded_file).convert("RGB")

elif input_method == "Use Image URL":
    image_url = st.text_input("Paste an image URL here")
    if image_url:
        try:
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception:
            st.error("Couldn't load an image from the provided URL.")

elif input_method == "Default Images":
    default_folder = "default_images/occupation_bias"
    if os.path.exists(default_folder):
        default_images = sorted(
            f for f in os.listdir(default_folder)
            if f.lower().endswith((".jpg", ".jpeg", ".png"))
        )
        if default_images:
            st.subheader("Choose a default image:")
            cols = st.columns(3)  # 3 images per row

            # Initialize session state to store the selected image
            if "selected_image" not in st.session_state:
                st.session_state.selected_image = None

            for idx, img_file in enumerate(default_images):
                img_path = os.path.join(default_folder, img_file)
                img = Image.open(img_path)
                with cols[idx % 3]:
                    st.image(img, caption=img_file, use_container_width=True)
                    if st.button(f"Select {chr(65 + idx)}", key=f"select_button_{idx}"):
                        st.session_state.selected_image = img_file

            if st.session_state.selected_image:
                image_path = os.path.join(default_folder, st.session_state.selected_image)
                image = Image.open(image_path).convert("RGB")
                st.success(f"Selected image: {st.session_state.selected_image}")
        else:
            st.warning("No default images found in the occupation_bias folder.")
    else:
        st.warning(f"Folder '{default_folder}' does not exist.")

st.divider()

# --- Display Image and Analyze Button ---
if image is not None:
    st.image(image, caption="Input Image", use_container_width=True)
    if st.button("Analyze Image"):
        if model_choice == "BLIP":
            analyze_with_blip(image)
        elif model_choice == "CLIP":
            analyze_with_clip(image)
else:
    st.info("Upload or select an image to get started.")