ratneshpasi03 commited on
Commit
143189f
Β·
verified Β·
1 Parent(s): b5fbbc7

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +6 -0
  2. pages/1_Clothing_Bias.py +90 -0
app.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ st.set_page_config(page_title="Vision Bias App", layout="wide")
4
+ st.title("The AI is **racist** too !")
5
+
6
+ st.write("Use the sidebar to navigate through different Domians.")
pages/1_Clothing_Bias.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import requests
4
+ from io import BytesIO
5
+ import os
6
+ import torch
7
+ from transformers import CLIPProcessor, CLIPModel
8
+ import pandas as pd
9
+ import matplotlib.pyplot as plt
10
+
11
+ # Set page config (optional)
12
+ st.set_page_config(page_title="Clothing Bias in Scene Classification", layout="wide")
13
+
14
+ # Load model once
15
+ @st.cache_resource
16
+ def load_clip_model():
17
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
18
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=False)
19
+ return model, processor
20
+
21
+
22
+ # Page title
23
+ st.header("πŸ‘– Clothing Bias in Scene Classification πŸ‘—")
24
+ st.markdown("""
25
+ This application explores biases in scene classification models related to clothing attributes.
26
+ It leverages the CLIP model to analyze and highlight these biases.
27
+ """)
28
+
29
+ # Image input
30
+ input_method = st.selectbox("Select Input Method", ["Default Images", "Upload Image", "Use Image URL"], index=0)
31
+
32
+ image = None
33
+ if input_method == "Upload Image":
34
+ uploaded_file = st.file_uploader("Upload your own image", type=["jpg", "png", "jpeg"])
35
+ if uploaded_file:
36
+ image = Image.open(uploaded_file).convert("RGB")
37
+ elif input_method == "Use Image URL":
38
+ image_url = st.text_input("Paste an image URL")
39
+ if image_url:
40
+ try:
41
+ response = requests.get(image_url)
42
+ image = Image.open(BytesIO(response.content)).convert("RGB")
43
+ except:
44
+ st.error("Couldn't load image from the provided URL.")
45
+ elif input_method == "Default Images":
46
+ default_dir = "default_images/clothing_bias"
47
+ default_images = sorted([f for f in os.listdir(default_dir) if f.lower().endswith((".jpg", ".png", ".jpeg"))])
48
+ default_choice = st.selectbox("Choose from default images", default_images, index=0)
49
+ image = Image.open(os.path.join(default_dir, default_choice)).convert("RGB")
50
+
51
+ # Display the image
52
+ if image is not None:
53
+ st.image(image, caption="Input Image", width=250)
54
+
55
+ # Prompt input
56
+ st.subheader("πŸ“ Candidate Scene Labels")
57
+ default_prompts = ["a business executive", "a festival participant"]
58
+ prompts_text = st.text_area("Enter one label per line:", "\n".join(default_prompts))
59
+ labels = [label.strip() for label in prompts_text.strip().split("\n") if label.strip()]
60
+
61
+ # Process and classify
62
+ if st.button("πŸ” Analyze Image"):
63
+ if image is None:
64
+ st.warning("⚠️ Please upload an image, paste a URL, or choose a default image before analysis.")
65
+ else:
66
+ model, processor = load_clip_model()
67
+ with st.spinner("Analyzing the image, please wait..."):
68
+ inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
69
+ with torch.no_grad():
70
+ outputs = model(**inputs)
71
+ probs = outputs.logits_per_image.softmax(dim=1)[0]
72
+
73
+ # Display classification
74
+ st.subheader("πŸ“Š Classification Probabilities")
75
+ df = pd.DataFrame({
76
+ "Label": labels,
77
+ "Probability": probs.numpy()
78
+ })
79
+ df.index += 1
80
+ st.table(df)
81
+ st.write("**Most likely label**:", labels[probs.argmax().item()])
82
+ st.write("\n")
83
+
84
+ # Show horizontal bar plot
85
+ fig, ax = plt.subplots(figsize=(6, 4))
86
+ ax.barh(labels, probs.numpy(), color='skyblue')
87
+ ax.set_xlim(0, 1)
88
+ ax.set_xlabel("Probability")
89
+ ax.set_title("Scene Classification")
90
+ st.pyplot(fig)