Spaces: Build error
```python
import streamlit as st
from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel
from PIL import Image
import requests
import os
import pandas as pd
import matplotlib.pyplot as plt
from io import BytesIO

# --- Caching models ---
@st.cache_resource
def load_blip_model():
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
    return processor, model

@st.cache_resource
def load_clip_model():
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    return processor, model
# --- Analysis Functions ---
def analyze_with_blip(image):
    st.subheader("BLIP Caption Generation")
    processor, model = load_blip_model()
    with st.spinner("Generating caption..."):
        inputs = processor(image, return_tensors="pt")
        out = model.generate(**inputs)
        caption = processor.decode(out[0], skip_special_tokens=True)
    st.success("Caption generated!")
    st.write(f"**Caption:** {caption}")
def analyze_with_clip(image):
    st.subheader("CLIP Occupation Classification")
    processor, model = load_clip_model()
    texts = ["a nurse", "a doctor", "a scientist", "an engineer"]
    with st.spinner("Classifying image..."):
        inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
        outputs = model(**inputs)
        probs = outputs.logits_per_image.softmax(dim=1).detach().numpy()[0]
    st.success("Classification done!")

    # Show probabilities as a table
    df = pd.DataFrame({
        "Label": texts,
        "Probability": probs
    })
    df.index += 1
    st.table(df)

    # Most likely label
    st.write(f"**Most likely label:** `{texts[probs.argmax()]}`")

    # Bar chart of label probabilities
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.barh(texts, probs, color="lightcoral")
    ax.set_xlabel("Probability")
    ax.set_xlim(0, 1)
    ax.set_title("Occupation Classification Probabilities")
    st.pyplot(fig)
# --- Main Content (NO FUNCTION WRAPPING) ---
st.title("Gender Bias in Occupation Detection")
st.markdown(
    """
    This demo highlights potential biases in occupation-related image tagging.
    For example, a woman in a lab might be labeled as a **nurse**, while a man might be labeled as a **scientist**.
    Select an image and model to explore!
    """
)
st.divider()
# --- Model and Input Options ---
model_choice = st.selectbox("Select Model", ["CLIP", "BLIP"])
input_method = st.selectbox("Select Input Method", ["Default Images", "Upload Image", "Use Image URL"])
image = None
# --- Image Input ---
if input_method == "Upload Image":
    uploaded_file = st.file_uploader("Upload your own image", type=["jpg", "jpeg", "png"])
    if uploaded_file:
        image = Image.open(uploaded_file).convert("RGB")
elif input_method == "Use Image URL":
    image_url = st.text_input("Paste an image URL here")
    if image_url:
        try:
            response = requests.get(image_url, timeout=10)  # avoid hanging on slow hosts
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception:
            st.error("Couldn't load image from the provided URL.")
| elif input_method == "Default Images": | |
| default_folder = "default_images/occupation_bias" | |
| if os.path.exists(default_folder): | |
| default_images = sorted([f for f in os.listdir(default_folder) if f.lower().endswith((".jpg", ".jpeg", ".png"))]) | |
| if default_images: | |
| st.subheader("Choose a default image:") | |
| cols = st.columns(3) # 3 images per row | |
| # Initialize session state to store selected image | |
| if "selected_image" not in st.session_state: | |
| st.session_state.selected_image = None | |
| for idx, img_file in enumerate(default_images): | |
| img_path = os.path.join(default_folder, img_file) | |
| img = Image.open(img_path) | |
| with cols[idx % 3]: | |
| st.image(img, caption=img_file, use_container_width=True) | |
| if st.button(f"Select {chr(65+idx)}", key=f"select_button_{idx}"): | |
| st.session_state.selected_image = img_file | |
| if st.session_state.selected_image: | |
| image_path = os.path.join(default_folder, st.session_state.selected_image) | |
| image = Image.open(image_path).convert("RGB") | |
| st.success(f"Selected image: {st.session_state.selected_image}") | |
| else: | |
| st.warning("No default images found in the occupation_bias folder.") | |
| else: | |
| st.warning(f"Folder '{default_folder}' does not exist.") | |
st.divider()

# --- Display Image and Analyze Button ---
if image is not None:
    st.image(image, caption="Input Image", use_container_width=True)
    if st.button("Analyze Image"):
        if model_choice == "BLIP":
            analyze_with_blip(image)
        elif model_choice == "CLIP":
            analyze_with_clip(image)
else:
    st.info("Upload or select an image to get started.")
```
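
If the Spaces build error comes from unresolved dependencies (a frequent cause), a minimal `requirements.txt` sketch along these lines may help. The package list is inferred from the imports above; `torch` is an assumption, included because `transformers` needs a backend to load the BLIP and CLIP weights:

```text
streamlit
transformers
torch
Pillow
requests
pandas
matplotlib
```

Check the Space's build log for the exact package that fails to install, since the right fix depends on which step errors out.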