ratneshpasi03 commited on
Commit
e643516
·
verified ·
1 Parent(s): ac3eeea

Update pages/1_Clothing_Bias.py

Browse files
Files changed (1) hide show
  1. pages/1_Clothing_Bias.py +32 -30
pages/1_Clothing_Bias.py CHANGED
@@ -8,32 +8,27 @@ from transformers import CLIPProcessor, CLIPModel
8
  import pandas as pd
9
  import matplotlib.pyplot as plt
10
 
11
- # Set page config
12
- st.set_page_config(page_title="Clothing Bias in Scene Classification", layout="wide")
13
-
14
@st.cache_resource
def load_clip_model():
    """Load the CLIP vision-language model and its processor (cached once per session).

    Returns:
        tuple: (CLIPModel, CLIPProcessor) for "openai/clip-vit-base-patch32".
    """
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    # use_fast=False keeps the slow (reference) processor implementation.
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=False)
    return clip_model, clip_processor
19
 
20
- st.header("Clothing Bias in Scene Classification")
 
21
  st.markdown("""
22
  This application explores biases in scene classification models related to clothing attributes.
23
  It leverages the CLIP model to analyze and highlight these biases.
24
  """)
25
 
26
- # Image input
27
  input_method = st.selectbox("Select Input Method", ["Default Images", "Upload Image", "Use Image URL"], index=0)
28
 
29
  image = None
30
- default_dir = "default_images/clothing_bias"
31
-
32
  if input_method == "Upload Image":
33
  uploaded_file = st.file_uploader("Upload your own image", type=["jpg", "png", "jpeg"])
34
  if uploaded_file:
35
  image = Image.open(uploaded_file).convert("RGB")
36
-
37
  elif input_method == "Use Image URL":
38
  image_url = st.text_input("Paste an image URL")
39
  if image_url:
@@ -42,30 +37,39 @@ elif input_method == "Use Image URL":
42
  image = Image.open(BytesIO(response.content)).convert("RGB")
43
  except:
44
  st.error("Couldn't load image from the provided URL.")
45
-
46
  elif input_method == "Default Images":
47
- default_images = sorted([f for f in os.listdir(default_dir) if f.lower().endswith((".jpg", ".png", ".jpeg"))])
48
- cols = st.columns(4)
49
- selected_file = None
50
 
51
- for i, img_name in enumerate(default_images):
52
- col = cols[i % 4]
53
- img_path = os.path.join(default_dir, img_name)
 
 
 
 
 
 
54
  with col:
55
- st.image(img_path, caption=chr(65 + i), use_container_width=True)
56
- if st.button(f"Select {chr(65 + i)}", key=f"select_{i}"):
57
- selected_file = img_path
 
 
 
 
58
 
59
- if selected_file:
60
- image = Image.open(selected_file).convert("RGB")
 
 
61
 
62
- # Display the selected image
63
  if image is not None:
64
- st.image(image, caption="Input Image", use_container_width=True)
65
 
66
  # Prompt input
67
- st.subheader("Candidate Scene Labels")
68
- default_prompts = ["a business executive", "a festival participant", "an Indian maid", "a school teacher"]
69
  prompts_text = st.text_area("Enter one label per line:", "\n".join(default_prompts))
70
  labels = [label.strip() for label in prompts_text.strip().split("\n") if label.strip()]
71
 
@@ -81,13 +85,11 @@ if st.button("πŸ” Analyze Image"):
81
  outputs = model(**inputs)
82
  probs = outputs.logits_per_image.softmax(dim=1)[0]
83
 
84
- # Display classification
85
  st.subheader("πŸ“Š Classification Probabilities")
86
- df = pd.DataFrame({
87
- "Label": labels,
88
- "Probability": probs.numpy()
89
- })
90
- df.index += 1
91
  st.table(df)
92
  st.write("**Most likely label**:", labels[probs.argmax().item()])
93
  st.write("\n")
 
8
  import pandas as pd
9
  import matplotlib.pyplot as plt
10
 
11
+ # Load model
 
 
12
@st.cache_resource
def load_clip_model():
    """Fetch and cache the CLIP checkpoint used for scene classification.

    Returns:
        tuple: the pretrained model and its paired processor.
    """
    checkpoint = "openai/clip-vit-base-patch32"
    model = CLIPModel.from_pretrained(checkpoint)
    # Slow processor requested explicitly (use_fast=False).
    processor = CLIPProcessor.from_pretrained(checkpoint, use_fast=False)
    return model, processor
17
 
18
+ # --- Main page content ---
19
+ st.header("πŸ‘– Clothing Bias in Scene Classification πŸ‘—")
20
  st.markdown("""
21
  This application explores biases in scene classification models related to clothing attributes.
22
  It leverages the CLIP model to analyze and highlight these biases.
23
  """)
24
 
 
25
  input_method = st.selectbox("Select Input Method", ["Default Images", "Upload Image", "Use Image URL"], index=0)
26
 
27
  image = None
 
 
28
  if input_method == "Upload Image":
29
  uploaded_file = st.file_uploader("Upload your own image", type=["jpg", "png", "jpeg"])
30
  if uploaded_file:
31
  image = Image.open(uploaded_file).convert("RGB")
 
32
  elif input_method == "Use Image URL":
33
  image_url = st.text_input("Paste an image URL")
34
  if image_url:
 
37
  image = Image.open(BytesIO(response.content)).convert("RGB")
38
  except:
39
  st.error("Couldn't load image from the provided URL.")
 
40
  elif input_method == "Default Images":
41
+ st.subheader("πŸ–ΌοΈ Select a Default Image")
 
 
42
 
43
+ image_dir = "default_images/clothing_bias"
44
+ default_images = sorted([f for f in os.listdir(image_dir) if f.lower().endswith((".jpg", ".jpeg", ".png"))])
45
+
46
+ selected_image = None
47
+ columns = st.columns(4) # Display images in 4 columns
48
+
49
+ for i, image_file in enumerate(default_images):
50
+ col = columns[i % 4]
51
+ img_path = os.path.join(image_dir, image_file)
52
  with col:
53
+ st.image(img_path, caption=image_file, use_column_width=True)
54
+ if st.button(f"Select {image_file}", key=image_file):
55
+ selected_image = image_file
56
+
57
+ # Store selected image using session state so selection persists
58
+ if selected_image:
59
+ st.session_state.selected_image = selected_image
60
 
61
+ if "selected_image" in st.session_state:
62
+ image_path = os.path.join(image_dir, st.session_state.selected_image)
63
+ image = Image.open(image_path).convert("RGB")
64
+ st.success(f"Selected: {st.session_state.selected_image}")
65
 
66
+ # Show the image if loaded
67
  if image is not None:
68
+ st.image(image, caption="Input Image", width=250)
69
 
70
  # Prompt input
71
+ st.subheader("πŸ“ Candidate Scene Labels")
72
+ default_prompts = ["a business executive", "a festival participant"]
73
  prompts_text = st.text_area("Enter one label per line:", "\n".join(default_prompts))
74
  labels = [label.strip() for label in prompts_text.strip().split("\n") if label.strip()]
75
 
 
85
  outputs = model(**inputs)
86
  probs = outputs.logits_per_image.softmax(dim=1)[0]
87
 
88
+ # Show probabilities
89
  st.subheader("πŸ“Š Classification Probabilities")
90
+ data = {"Label": labels, "Probability": probs.numpy()}
91
+ df = pd.DataFrame(data)
92
+ df.index += 1 # Start index from 1
 
 
93
  st.table(df)
94
  st.write("**Most likely label**:", labels[probs.argmax().item()])
95
  st.write("\n")