Satoshi2077 committed (verified)
Commit 9da140c · Parent(s): 0572872

Create app.py

Files changed (1):
  app.py +233 -0
app.py (added, 233 lines):
import os
import urllib.request
from functools import lru_cache
from random import randint
from typing import Any, Callable, Dict, List, Tuple

import clip
import cv2
import gradio as gr
import numpy as np
import PIL.Image
import torch
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), ".cache", "SAM")
CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
MODEL_TYPE = "default"
MAX_WIDTH = MAX_HEIGHT = 1024
TOP_K_OBJ = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@lru_cache
def load_mask_generator() -> SamAutomaticMaskGenerator:
    # Download the SAM ViT-H checkpoint to ~/.cache/SAM on first use.
    if not os.path.exists(CHECKPOINT_PATH):
        os.makedirs(CHECKPOINT_PATH)
    checkpoint = os.path.join(CHECKPOINT_PATH, CHECKPOINT_NAME)
    if not os.path.exists(checkpoint):
        urllib.request.urlretrieve(CHECKPOINT_URL, checkpoint)
    sam = sam_model_registry[MODEL_TYPE](checkpoint=checkpoint).to(device)
    mask_generator = SamAutomaticMaskGenerator(sam)
    return mask_generator

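# Note: each mask dict produced by SamAutomaticMaskGenerator.generate() carries,
# among other fields, "segmentation" (a boolean HxW array), "bbox" (x, y, w, h),
# "area", "predicted_iou" and "stability_score"; the helpers below rely on
# exactly these keys.
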
@lru_cache
def load_clip(
    name: str = "ViT-B/32",
) -> Tuple[torch.nn.Module, Callable[[PIL.Image.Image], torch.Tensor]]:
    model, preprocess = clip.load(name, device=device)
    return model.to(device), preprocess

def adjust_image_size(image: np.ndarray) -> np.ndarray:
    height, width = image.shape[:2]
    if height > width:
        if height > MAX_HEIGHT:
            height, width = MAX_HEIGHT, int(MAX_HEIGHT / height * width)
    else:
        if width > MAX_WIDTH:
            height, width = int(MAX_WIDTH / width * height), MAX_WIDTH
    image = cv2.resize(image, (width, height))
    return image

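# Example: a 2048x1536 (width x height) landscape input is resized to 1024x768,
# and a 1536x2048 portrait input becomes 768x1024; an image already within the
# 1024x1024 limits keeps its original dimensions.
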
@torch.no_grad()
def get_score(crop: PIL.Image.Image, texts: List[str]) -> torch.Tensor:
    model, preprocess = load_clip()
    preprocessed = preprocess(crop).unsqueeze(0).to(device)
    tokens = clip.tokenize(texts).to(device)
    logits_per_image, _ = model(preprocessed, tokens)
    similarity = logits_per_image.softmax(-1).cpu()
    return similarity[0, 0]

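# Example: with query "dog", get_texts() below yields ["a picture of dog",
# "a picture of background"], and get_score() returns the softmax probability
# that CLIP assigns to the first prompt for the given crop; a crop of a dog
# should score close to 1.0 and a background crop close to 0.0.
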
def crop_image(image: np.ndarray, mask: Dict[str, Any]) -> PIL.Image.Image:
    x, y, w, h = mask["bbox"]
    masked = image * np.expand_dims(mask["segmentation"], -1)
    crop = masked[y : y + h, x : x + w]
    if h > w:
        top, bottom, left, right = 0, 0, (h - w) // 2, (h - w) // 2
    else:
        top, bottom, left, right = (w - h) // 2, (w - h) // 2, 0, 0
    # padding
    crop = cv2.copyMakeBorder(
        crop,
        top,
        bottom,
        left,
        right,
        cv2.BORDER_CONSTANT,
        value=(0, 0, 0),
    )
    crop = PIL.Image.fromarray(crop)
    return crop

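# Design note (interpretation): crop_image() blacks out everything outside the
# mask and zero-pads the bounding-box crop to a square, so CLIP's fixed-size
# preprocessing scores only the segmented region without stretching it.
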
def get_texts(query: str) -> List[str]:
    return [f"a picture of {query}", "a picture of background"]

def filter_masks(
    image: np.ndarray,
    masks: List[Dict[str, Any]],
    predicted_iou_threshold: float,
    stability_score_threshold: float,
    query: str,
    clip_threshold: float,
) -> List[Dict[str, Any]]:
    filtered_masks: List[Dict[str, Any]] = []

    # consider only the TOP_K_OBJ largest masks (by area)
    for mask in sorted(masks, key=lambda mask: mask["area"])[-TOP_K_OBJ:]:
        if (
            mask["predicted_iou"] < predicted_iou_threshold
            or mask["stability_score"] < stability_score_threshold
            or image.shape[:2] != mask["segmentation"].shape[:2]
            or (
                query
                and get_score(crop_image(image, mask), get_texts(query)) < clip_threshold
            )
        ):
            continue

        filtered_masks.append(mask)

    return filtered_masks

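# In other words, a mask survives only if predicted_iou >= predicted_iou_threshold,
# stability_score >= stability_score_threshold, its segmentation matches the image
# size, and, when a query is provided, its CLIP score is >= clip_threshold.
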
117
+ def draw_masks(
118
+ image: np.ndarray, masks: List[np.ndarray], alpha: float = 0.7
119
+ ) -> np.ndarray:
120
+ for mask in masks:
121
+ color = [randint(127, 255) for _ in range(3)]
122
+
123
+ # draw mask overlay
124
+ colored_mask = np.expand_dims(mask["segmentation"], 0).repeat(3, axis=0)
125
+ colored_mask = np.moveaxis(colored_mask, 0, -1)
126
+ masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color)
127
+ image_overlay = masked.filled()
128
+ image = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
129
+
130
+ # draw contour
131
+ contours, _ = cv2.findContours(
132
+ np.uint8(mask["segmentation"]), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
133
+ )
134
+ cv2.drawContours(image, contours, -1, (0, 0, 255), 2)
135
+ return image
136
+
137
+
def segment(
    predicted_iou_threshold: float,
    stability_score_threshold: float,
    clip_threshold: float,
    image_path: str,
    query: str,
) -> PIL.Image.Image:
    mask_generator = load_mask_generator()
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # reduce the size to save gpu memory
    image = adjust_image_size(image)
    masks = mask_generator.generate(image)
    masks = filter_masks(
        image,
        masks,
        predicted_iou_threshold,
        stability_score_threshold,
        query,
        clip_threshold,
    )
    image = draw_masks(image, masks)
    image = PIL.Image.fromarray(image)
    return image

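# Example (sketch, not exercised here): the pipeline can also be called directly,
# e.g.
#   result = segment(0.9, 0.8, 0.85, "examples/dog.jpg", "dog")
#   result.save("segmented_dog.png")
# where "examples/dog.jpg" is one of the sample images referenced below.
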
demo = gr.Interface(
    fn=segment,
    inputs=[
        gr.Slider(0, 1, value=0.9, label="predicted_iou_threshold"),
        gr.Slider(0, 1, value=0.8, label="stability_score_threshold"),
        gr.Slider(0, 1, value=0.85, label="clip_threshold"),
        gr.Image(type="filepath"),
        "text",
    ],
    outputs="image",
    allow_flagging="never",
    title="Segment Anything with CLIP",
    examples=[
        [
            0.9,
            0.8,
            0.99,
            os.path.join(os.path.dirname(__file__), "examples/dog.jpg"),
            "dog",
        ],
        [
            0.9,
            0.8,
            0.75,
            os.path.join(os.path.dirname(__file__), "examples/city.jpg"),
            "building",
        ],
        [
            0.9,
            0.8,
            0.998,
            os.path.join(os.path.dirname(__file__), "examples/food.jpg"),
            "strawberry",
        ],
        [
            0.9,
            0.8,
            0.75,
            os.path.join(os.path.dirname(__file__), "examples/horse.jpg"),
            "horse",
        ],
        [
            0.9,
            0.8,
            0.99,
            os.path.join(os.path.dirname(__file__), "examples/bears.jpg"),
            "bear",
        ],
        [
            0.9,
            0.8,
            0.99,
            os.path.join(os.path.dirname(__file__), "examples/cats.jpg"),
            "cat",
        ],
        [
            0.9,
            0.8,
            0.99,
            os.path.join(os.path.dirname(__file__), "examples/fish.jpg"),
            "fish",
        ],
    ],
)

if __name__ == "__main__":
    demo.launch(share=True)
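# Note: the Gradio examples above assume an `examples/` directory next to app.py
# containing dog.jpg, city.jpg, food.jpg, horse.jpg, bears.jpg, cats.jpg and
# fish.jpg. Running the app assumes torch, opencv-python, gradio, the OpenAI
# `clip` package and `segment_anything` are installed; the large SAM ViT-H
# checkpoint is downloaded automatically to ~/.cache/SAM on first run.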