File size: 4,033 Bytes
4548917
ed6edb0
4548917
 
 
 
 
 
 
 
 
e643516
4548917
 
 
 
 
 
e643516
ed6edb0
4548917
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d99e0b0
4548917
e643516
 
 
 
 
 
 
 
 
ac3eeea
ed6edb0
 
e643516
 
 
 
 
ba20e08
e643516
 
 
 
ba20e08
e643516
4548917
e643516
4548917
ac3eeea
d99e0b0
ed6edb0
4548917
 
 
ac3eeea
d99e0b0
4548917
d99e0b0
4548917
 
 
 
 
 
 
 
e643516
d99e0b0
e643516
 
 
4548917
 
ac3eeea
4548917
ac3eeea
4548917
 
 
 
 
ba20e08
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import streamlit as st
import string as s
from PIL import Image
import requests
from io import BytesIO
import os
import torch
from transformers import CLIPProcessor, CLIPModel
import pandas as pd
import matplotlib.pyplot as plt

# Load model
@st.cache_resource
def load_clip_model():
    """Load the pretrained CLIP model and its matching processor.

    Both objects come from the "openai/clip-vit-base-patch32" checkpoint.
    The function is wrapped in ``st.cache_resource``.

    Returns:
        A ``(model, processor)`` tuple.
    """
    checkpoint = "openai/clip-vit-base-patch32"
    clip_model = CLIPModel.from_pretrained(checkpoint)
    # use_fast=False is passed through to the processor loader unchanged.
    clip_processor = CLIPProcessor.from_pretrained(checkpoint, use_fast=False)
    return clip_model, clip_processor

# --- Main page content ---
# Title and short introduction rendered at the top of the page.
page_title = "Clothing Bias in Scene Classification"
page_intro = """
This application explores biases in scene classification models related to clothing attributes.
It leverages the CLIP model to analyze and highlight these biases.
"""
st.header(page_title)
st.markdown(page_intro)

# The three supported ways of supplying an image; the first entry is preselected.
input_choices = ["Default Images", "Upload Image", "Use Image URL"]
input_method = st.selectbox("Select Input Method", input_choices, index=0)

def _gallery_button_label(index):
    """Return a spreadsheet-style label for a zero-based index.

    0 -> "A", 25 -> "Z", 26 -> "AA", 27 -> "AB", ...  Unlike indexing
    ``string.ascii_uppercase`` directly, this never runs out of labels when
    the gallery holds more than 26 images.
    """
    label = ""
    remaining = index + 1
    while remaining > 0:
        remaining, offset = divmod(remaining - 1, 26)
        label = s.ascii_uppercase[offset] + label
    return label


# Resolve the input image according to the selected input method.
# `image` stays None until one of the branches below loads a picture.
image = None
if input_method == "Upload Image":
    uploaded_file = st.file_uploader("Upload your own image", type=["jpg", "png", "jpeg"])
    if uploaded_file:
        image = Image.open(uploaded_file).convert("RGB")
elif input_method == "Use Image URL":
    image_url = st.text_input("Paste an image URL")
    if image_url:
        try:
            # requests applies no timeout by default; without one an
            # unresponsive host would block the script indefinitely.
            response = requests.get(image_url, timeout=10)
            # Fail on HTTP errors instead of trying to decode an error page.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception:
            # Best-effort UI boundary: any download or decode failure is
            # reported to the user. Unlike a bare `except:`, this does not
            # swallow KeyboardInterrupt / SystemExit.
            st.error("Couldn't load image from the provided URL.")
elif input_method == "Default Images":
    st.subheader("Select a Default Image")

    image_dir = "default_images/clothing_bias"
    if os.path.isdir(image_dir):
        default_images = sorted(
            f for f in os.listdir(image_dir)
            if f.lower().endswith((".jpg", ".jpeg", ".png"))
        )
    else:
        # A missing folder is reported instead of crashing in os.listdir.
        default_images = []
        st.warning(f"Default image folder '{image_dir}' was not found.")

    selected_image = None
    columns = st.columns(4)  # Display images in 4 columns

    for i, image_file in enumerate(default_images):
        col = columns[i % 4]
        img_path = os.path.join(image_dir, image_file)
        with col:
            st.image(img_path, caption=image_file, use_container_width=True)
            if st.button(f"Select {_gallery_button_label(i)}", key=image_file):
                selected_image = image_file

    # Store selected image using session state so selection persists
    if selected_image:
        st.session_state.selected_image = selected_image

    # Only reuse a remembered selection if that file is still in the gallery;
    # a stale name would otherwise make Image.open fail.
    if (
        "selected_image" in st.session_state
        and st.session_state.selected_image in default_images
    ):
        image_path = os.path.join(image_dir, st.session_state.selected_image)
        image = Image.open(image_path).convert("RGB")
        st.success(f"Selected: {st.session_state.selected_image}")

# Preview the input image, if one of the input methods produced one.
if image is not None:
    st.image(image, caption="Input Image", width=250)

# Candidate labels: the text area is pre-filled with the defaults, one per line.
st.subheader("Candidate Scene Labels")
default_prompts = ["a business executive", "a festival participant", "a indian maid", "a school teacher"]
prompts_text = st.text_area("Enter one label per line:", "\n".join(default_prompts))

# Keep every non-blank line, trimmed of surrounding whitespace.
labels = []
for raw_line in prompts_text.strip().split("\n"):
    cleaned = raw_line.strip()
    if cleaned:
        labels.append(cleaned)

# Analyze button: score the input image against every candidate label with
# CLIP, then show the probabilities as a table and a horizontal bar chart.
if st.button("Analyze Image"):
    if image is None:
        st.warning("Please upload an image, paste a URL, or choose a default image before analysis.")
    elif not labels:
        # With no labels there is nothing to score; stop before calling the
        # processor/model with an empty text list.
        st.warning("Please enter at least one candidate label before analysis.")
    else:
        model, processor = load_clip_model()
        with st.spinner("Analyzing the image, please wait..."):
            inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
            with torch.no_grad():
                outputs = model(**inputs)
                # Softmax over the label axis; [0] selects the single input image.
                probs = outputs.logits_per_image.softmax(dim=1)[0]

        # Convert once and reuse for both the table and the chart.
        prob_values = probs.numpy()

        # Show probabilities
        st.subheader("Classification Probabilities")
        data = {"Label": labels, "Probability": prob_values}
        df = pd.DataFrame(data)
        df.index += 1  # Start index from 1
        st.table(df)
        st.write("**Most likely label**:", labels[probs.argmax().item()])
        st.write("\n")

        # Bar plot
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.barh(labels, prob_values, color='skyblue')
        ax.set_xlim(0, 1)
        ax.set_xlabel("Probability")
        ax.set_title("Scene Classification")
        st.pyplot(fig)
        # pyplot keeps every figure it creates registered until closed; the
        # script reruns on each interaction, so release it explicitly.
        plt.close(fig)