Spaces:
Build error
Build error
| import streamlit as st | |
| import string as s | |
| from PIL import Image | |
| import requests | |
| from io import BytesIO | |
| import os | |
| import torch | |
| from transformers import CLIPProcessor, CLIPModel | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| # Load model | |
@st.cache_resource
def load_clip_model():
    """Load the pretrained CLIP model and its matching processor.

    Both objects are loaded from the ``openai/clip-vit-base-patch32``
    checkpoint; ``use_fast=False`` is passed through to the processor loader.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching each click on "Analyze Image" would load the weights
    again. ``st.cache_resource`` keeps one shared instance alive across
    reruns and sessions.

    Returns:
        tuple: ``(model, processor)`` as returned by
        ``CLIPModel.from_pretrained`` and ``CLIPProcessor.from_pretrained``.
    """
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=False)
    return model, processor
# --- Main page content ---
# Page title, followed by a short description of what the app demonstrates.
st.header("Clothing Bias in Scene Classification")

intro_text = """
This application explores biases in scene classification models related to clothing attributes.
It leverages the CLIP model to analyze and highlight these biases.
"""
st.markdown(intro_text)
# --- Input selection ---
# Three sources feed the same `image` variable (a PIL image converted to RGB,
# or None while nothing has been chosen): an uploaded file, a remote URL, or
# one of the bundled default images.
input_method = st.selectbox(
    "Select Input Method",
    ["Default Images", "Upload Image", "Use Image URL"],
    index=0,
)
image = None

if input_method == "Upload Image":
    uploaded_file = st.file_uploader("Upload your own image", type=["jpg", "png", "jpeg"])
    if uploaded_file:
        try:
            image = Image.open(uploaded_file).convert("RGB")
        except (OSError, ValueError, Image.DecompressionBombError):
            # A file with an image extension can still be corrupt or not an
            # image at all; report it instead of showing a traceback.
            st.error("Couldn't read the uploaded file as an image.")

elif input_method == "Use Image URL":
    image_url = st.text_input("Paste an image URL")
    if image_url:
        try:
            # The timeout stops an unresponsive host from hanging the app, and
            # raise_for_status() turns 4xx/5xx replies into errors instead of
            # trying to decode an HTML error page as an image.
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except (requests.RequestException, OSError, ValueError, Image.DecompressionBombError):
            # Named exception types only: a bare `except:` would also swallow
            # BaseException subclasses such as KeyboardInterrupt.
            st.error("Couldn't load image from the provided URL.")

elif input_method == "Default Images":
    st.subheader("Select a Default Image")
    image_dir = "default_images/clothing_bias"
    try:
        default_images = sorted(
            f for f in os.listdir(image_dir)
            if f.lower().endswith((".jpg", ".jpeg", ".png"))
        )
    except OSError:
        # Missing or unreadable folder: show an empty gallery, not a crash.
        default_images = []
        st.warning(f"Default image folder '{image_dir}' could not be read.")

    selected_image = None
    columns = st.columns(4)  # Display images in 4 columns
    for i, image_file in enumerate(default_images):
        # Buttons are tagged A..Z; past 26 images fall back to the 1-based
        # position so indexing the alphabet can never raise IndexError.
        if i < len(s.ascii_uppercase):
            tag = s.ascii_uppercase[i]
        else:
            tag = str(i + 1)
        img_path = os.path.join(image_dir, image_file)
        with columns[i % 4]:
            st.image(img_path, caption=image_file, use_container_width=True)
            if st.button(f"Select {tag}", key=image_file):
                selected_image = image_file

    # Store selected image using session state so selection persists
    # across the reruns triggered by other widgets.
    if selected_image:
        st.session_state.selected_image = selected_image

    # Only reopen a remembered selection that is still in the current listing;
    # a stale name would otherwise make Image.open raise.
    if (
        "selected_image" in st.session_state
        and st.session_state.selected_image in default_images
    ):
        image_path = os.path.join(image_dir, st.session_state.selected_image)
        image = Image.open(image_path).convert("RGB")
        st.success(f"Selected: {st.session_state.selected_image}")
# Preview whichever image was loaded above, if any.
if image is not None:
    st.image(image, caption="Input Image", width=250)

# Candidate labels: an editable text area, one label per line, pre-filled
# with a default set of prompts.
st.subheader("Candidate Scene Labels")
default_prompts = ["a business executive", "a festival participant", "a indian maid", "a school teacher"]
prompts_text = st.text_area("Enter one label per line:", "\n".join(default_prompts))

# Keep only non-blank lines, with surrounding whitespace removed.
labels = []
for raw_line in prompts_text.strip().split("\n"):
    cleaned = raw_line.strip()
    if cleaned:
        labels.append(cleaned)
# --- Analysis ---
# On click, score the chosen image against every candidate label with CLIP
# and present the result as a table, a top-1 line and a horizontal bar chart.
if st.button("Analyze Image"):
    if image is None:
        st.warning("Please upload an image, paste a URL, or choose a default image before analysis.")
    elif not labels:
        # With no labels there is nothing to tokenize or take an argmax over.
        st.warning("Please enter at least one candidate label before analysis.")
    else:
        model, processor = load_clip_model()
        with st.spinner("Analyzing the image, please wait..."):
            inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
            with torch.no_grad():  # inference only, no gradients needed
                outputs = model(**inputs)
            # Softmax across the label axis, then take the row for the single
            # input image.
            probs = outputs.logits_per_image.softmax(dim=1)[0]

        # Convert once and reuse for both the table and the chart.
        scores = probs.numpy()

        # Show probabilities
        st.subheader("Classification Probabilities")
        df = pd.DataFrame({"Label": labels, "Probability": scores})
        df.index += 1  # Start index from 1
        st.table(df)
        st.write("**Most likely label**:", labels[probs.argmax().item()])
        st.write("\n")

        # Bar plot
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.barh(labels, scores, color='skyblue')
        ax.set_xlim(0, 1)
        ax.set_xlabel("Probability")
        ax.set_title("Scene Classification")
        st.pyplot(fig)
        # pyplot keeps every figure registered until it is closed; without
        # this, figures pile up in memory on each rerun.
        plt.close(fig)