Johnny-Z committed on
Commit 7126f84 · verified · 1 Parent(s): f8efae2

Upload 2 files

Files changed (2)
  1. app.py +473 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,473 @@
+ from transformers import CLIPImageProcessor, AutoModel
+ import torch
+ import json
+ import torch.nn as nn
+ from PIL import Image
+ import gradio as gr
+ import os
+ import faiss
+ import time
+ import requests
+ from huggingface_hub import login, snapshot_download
+
+ TITLE = "Danbooru Tagger"
+ DESCRIPTION = """
+ ## Dataset
+ - Source: Cleaned Danbooru
+
+ ## Metrics
+ - Validation Split: 10% of Dataset
+ - Validation Results:
+
+ ### General
+ | Metric          | Value  |
+ |-----------------|--------|
+ | Macro F1        | 0.4678 |
+ | Macro Precision | 0.4605 |
+ | Macro Recall    | 0.5229 |
+ | Micro F1        | 0.6661 |
+ | Micro Precision | 0.6049 |
+ | Micro Recall    | 0.7411 |
+
+ ### Character
+ | Metric          | Value  |
+ |-----------------|--------|
+ | Macro F1        | 0.8925 |
+ | Macro Precision | 0.9099 |
+ | Macro Recall    | 0.8935 |
+ | Micro F1        | 0.9232 |
+ | Micro Precision | 0.9264 |
+ | Micro Recall    | 0.9199 |
+
+ ### Artist
+ | Metric          | Value  |
+ |-----------------|--------|
+ | Macro F1        | 0.7904 |
+ | Macro Precision | 0.8286 |
+ | Macro Recall    | 0.7904 |
+ | Micro F1        | 0.5989 |
+ | Micro Precision | 0.5975 |
+ | Micro Recall    | 0.6004 |
+ """
+ kaomojis = [
+     "0_0",
+     "(o)_(o)",
+     "+_+",
+     "+_-",
+     "._.",
+     "<o>_<o>",
+     "<|>_<|>",
+     "=_=",
+     ">_<",
+     "3_3",
+     "6_9",
+     ">_o",
+     "@_@",
+     "^_^",
+     "o_o",
+     "u_u",
+     "x_x",
+     "|_|",
+     "||_||",
+ ]
+
+ device = torch.device('cpu')
+ dtype = torch.float32
+
+ hf_token = os.getenv("HF_TOKEN")
+ if hf_token:
+     login(token=hf_token)
+ else:
+     raise ValueError("Environment variable HF_TOKEN is not set.")
+
+ repo = snapshot_download('Johnny-Z/vit-e4')
+ model = AutoModel.from_pretrained(repo, dtype=dtype, trust_remote_code=True, device_map=device)
+
+ index_dir = snapshot_download('Johnny-Z/dan_index', repo_type='dataset')
+
+ processor = CLIPImageProcessor.from_pretrained(repo)
+
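+ # Multihead attention pooling (MAP) head: a single learned probe token
+ # cross-attends over the backbone's token embeddings (width 2048 here) and
+ # pools them into one image-level vector.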
+ class MultiheadAttentionPoolingHead(nn.Module):
+     def __init__(self, input_size):
+         super().__init__()
+
+         self.map_probe = nn.Parameter(torch.randn(1, 1, input_size))
+         self.map_layernorm0 = nn.LayerNorm(input_size, eps=1e-08)
+         self.map_attention = torch.nn.MultiheadAttention(input_size, input_size // 64, batch_first=True)
+         self.map_layernorm1 = nn.LayerNorm(input_size, eps=1e-08)
+         self.map_ffn = nn.Sequential(
+             nn.Linear(input_size, input_size * 4),
+             nn.SiLU(),
+             nn.Linear(input_size * 4, input_size)
+         )
+
+     def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+         batch_size = hidden_state.shape[0]
+         probe = self.map_probe.repeat(batch_size, 1, 1)
+
+         hidden_state = self.map_layernorm0(hidden_state)
+         hidden_state = self.map_attention(probe, hidden_state, hidden_state)[0]
+         hidden_state = self.map_layernorm1(hidden_state)
+
+         residual = hidden_state
+         hidden_state = residual + self.map_ffn(hidden_state)
+         return hidden_state[:, 0]
+
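+ # Classifier head: per-class sigmoid outputs for multi-label tag prediction,
+ # used separately for the general, character, and artist vocabularies.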
+ class MLP(nn.Module):
+     def __init__(self, input_size, class_num):
+         super().__init__()
+         self.mlp_layer0 = nn.Sequential(
+             nn.LayerNorm(input_size, eps=1e-08),
+             nn.Linear(input_size, input_size // 2),
+             nn.SiLU()
+         )
+         self.mlp_layer1 = nn.Linear(input_size // 2, class_num)
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x):
+         x = self.mlp_layer0(x)
+         x = self.mlp_layer1(x)
+         x = self.sigmoid(x)
+         return x
+
+ class MLP_Retrieval(nn.Module):
+     def __init__(self, input_size, class_num):
+         super().__init__()
+         self.mlp_layer0 = nn.Sequential(
+             nn.Linear(input_size, input_size // 2),
+             nn.SiLU()
+         )
+         self.mlp_layer1 = nn.Linear(input_size // 2, class_num)
+
+     def forward(self, x):
+         x = self.mlp_layer0(x)
+         x = self.mlp_layer1(x)
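+         # The logits are split into two groups that are softmaxed
+         # independently. Judging by prediction_to_retrieval below, the first
+         # 15 outputs appear to correspond to the date (year) classes and the
+         # remainder to artist classes, but that split is an assumption.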
+         x1, x2 = x[:, :15], x[:, 15:]
+         x1 = torch.softmax(x1, dim=1)
+         x2 = torch.softmax(x2, dim=1)
+         x = torch.cat([x1, x2], dim=1)
+
+         return x
+
+ class MLP_R(nn.Module):
+     def __init__(self, input_size):
+         super().__init__()
+         self.mlp_layer0 = nn.Sequential(
+             nn.Linear(input_size, 256),
+         )
+
+     def forward(self, x):
+         x = self.mlp_layer0(x)
+         return x
+
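+ # Tag metadata shipped with the model repo. In each tag dictionary a tag
+ # maps to a tuple whose second field is its category ("general",
+ # "character", "artist", "rating", "date") and whose third field is its
+ # 1-based class id; the *_threshold.json files hold tuned per-tag decision
+ # thresholds under the "Threshold" key.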
+ with open(os.path.join(repo, 'general_tag_dict.json'), 'r', encoding='utf-8') as f:
+     general_dict = json.load(f)
+
+ with open(os.path.join(repo, 'character_tag_dict.json'), 'r', encoding='utf-8') as f:
+     character_dict = json.load(f)
+
+ with open(os.path.join(repo, 'artist_tag_dict.json'), 'r', encoding='utf-8') as f:
+     artist_dict = json.load(f)
+
+ with open(os.path.join(repo, 'implications_list.json'), 'r', encoding='utf-8') as f:
+     implications_list = json.load(f)
+
+ with open(os.path.join(repo, 'artist_threshold.json'), 'r', encoding='utf-8') as f:
+     artist_thresholds = json.load(f)
+
+ with open(os.path.join(repo, 'character_threshold.json'), 'r', encoding='utf-8') as f:
+     character_thresholds = json.load(f)
+
+ with open(os.path.join(repo, 'general_threshold.json'), 'r', encoding='utf-8') as f:
+     general_thresholds = json.load(f)
+
+ model_map = MultiheadAttentionPoolingHead(2048)
+ model_map.load_state_dict(torch.load(os.path.join(repo, "map_head.pth"), map_location=device, weights_only=True))
+ model_map.to(device).to(dtype).eval()
+
+ general_class = 9775
+ mlp_general = MLP(2048, general_class)
+ mlp_general.load_state_dict(torch.load(os.path.join(repo, "cls_predictor_general.pth"), map_location=device, weights_only=True))
+ mlp_general.to(device).to(dtype).eval()
+
+ character_class = 7568
+ mlp_character = MLP(2048, character_class)
+ mlp_character.load_state_dict(torch.load(os.path.join(repo, "cls_predictor_character.pth"), map_location=device, weights_only=True))
+ mlp_character.to(device).to(dtype).eval()
+
+ artist_class = 13957
+ mlp_artist = MLP(2048, artist_class)
+ mlp_artist.load_state_dict(torch.load(os.path.join(repo, "cls_predictor_artist.pth"), map_location=device, weights_only=True))
+ mlp_artist.to(device).to(dtype).eval()
+
+ mlp_artist_retrieval = MLP_Retrieval(2048, artist_class)
+ mlp_artist_retrieval.load_state_dict(torch.load(os.path.join(repo, "cls_predictor_artist_retrieval.pth"), map_location=device, weights_only=True))
+ mlp_artist_retrieval.to(device).to(dtype).eval()
+
+ mlp_r = MLP_R(2048)
+ mlp_r.load_state_dict(torch.load(os.path.join(repo, "retrieval_head.pth"), map_location=device, weights_only=True))
+ mlp_r.to(device).to(dtype).eval()
+
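+ # Convert raw sigmoid scores into tag dictionaries: a coarse 0.2 cutoff
+ # first prunes the candidate set, then each surviving tag must clear its own
+ # tuned threshold (defaulting to 0.75 when no per-tag value exists). Only
+ # the single highest-scoring date and rating tag are kept.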
+ def prediction_to_tag(prediction, tag_dict, class_num):
+     prediction = prediction.view(class_num)
+     predicted_ids = (prediction >= 0.2).nonzero(as_tuple=True)[0].cpu().numpy() + 1
+
+     general = {}
+     character = {}
+     artist = {}
+     date = {}
+     rating = {}
+
+     for tag, value in tag_dict.items():
+         if value[2] in predicted_ids:
+             tag_value = round(prediction[value[2] - 1].item(), 6)
+             if value[1] == "general" and tag_value >= general_thresholds.get(tag, {}).get("Threshold", 0.75):
+                 general[tag] = tag_value
+             elif value[1] == "character" and tag_value >= character_thresholds.get(tag, {}).get("Threshold", 0.75):
+                 character[tag] = tag_value
+             elif value[1] == "artist" and tag_value >= artist_thresholds.get(tag, {}).get("Threshold", 0.75):
+                 artist[tag] = tag_value
+             elif value[1] == "rating":
+                 rating[tag] = tag_value
+             elif value[1] == "date":
+                 date[tag] = tag_value
+
+     general = dict(sorted(general.items(), key=lambda item: item[1], reverse=True))
+     character = dict(sorted(character.items(), key=lambda item: item[1], reverse=True))
+     artist = dict(sorted(artist.items(), key=lambda item: item[1], reverse=True))
+
+     if date:
+         date = {max(date, key=date.get): date[max(date, key=date.get)]}
+     if rating:
+         rating = {max(rating, key=rating.get): rating[max(rating, key=rating.get)]}
+
+     return general, character, artist, date, rating
+
+ def prediction_to_retrieval(prediction, tag_dict, class_num, top_k):
+     prediction = prediction.view(class_num)
+     predicted_ids = (prediction >= 0.005).nonzero(as_tuple=True)[0].cpu().numpy() + 1
+
+     artist = {}
+     date = {}
+
+     for tag, value in tag_dict.items():
+         if value[2] in predicted_ids:
+             tag_value = round(prediction[value[2] - 1].item(), 6)
+             if value[1] == "artist":
+                 artist[tag] = tag_value
+             elif value[1] == "date":
+                 date[tag] = tag_value
+
+     artist = dict(sorted(artist.items(), key=lambda item: item[1], reverse=True))
+     artist = dict(list(artist.items())[:top_k])
+
+     if date:
+         date = {max(date, key=date.get): date[max(date, key=date.get)]}
+
+     return artist, date
+
+ def load_id_map(id_map_path):
+     with open(id_map_path, "r") as f:
+         id_map = json.load(f)
+
+     id_map = {int(k): int(v) for k, v in id_map.items()}
+
+     inv_map = {v: k for k, v in id_map.items()}
+     return id_map, inv_map
+
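+ # FAISS L2 indexes report *squared* L2 distances, so the user-facing
+ # thresholds are squared before filtering and the square root is taken again
+ # when reporting. Note the index is re-read from disk on every query, which
+ # is simple but not the fastest option.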
+ def search_index(query_vector, k=32, distance_threshold_min=0, distance_threshold_max=64, nprobe=4):
+     index_path = os.path.join(index_dir, 'danbooru_retrieval.index')
+     id_map_path = os.path.join(index_dir, 'danbooru_retrieval_id_map.json')
+     distance_threshold_min = distance_threshold_min**2
+     distance_threshold_max = distance_threshold_max**2
+
+     index = faiss.read_index(index_path)
+
+     if nprobe is not None and hasattr(index, "nprobe"):
+         index.nprobe = nprobe
+     _, inv_map = load_id_map(id_map_path)
+
+     qv = query_vector.detach().to(torch.float32).cpu().numpy()
+
+     # k may arrive as a float from the UI slider; FAISS expects an int.
+     distances, internal_ids = index.search(qv, int(k))
+     distances = distances[0]
+     internal_ids = internal_ids[0]
+
+     results = []
+     for dist, internal_id in zip(distances, internal_ids):
+         if internal_id == -1:
+             continue
+         if dist < distance_threshold_min or dist > distance_threshold_max:
+             continue
+         original_id = inv_map.get(int(internal_id))
+         if original_id is None:
+             continue
+         results.append({"original_id": original_id, "l2_distance": float(dist**0.5)})
+     results.sort(key=lambda x: x["l2_distance"])
+
+     return results
+
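+ # Resolve retrieved post ids to image URLs via the public Danbooru JSON API,
+ # sleeping between requests to stay polite to the server; the finally block
+ # enforces the delay on every path, including errors.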
+ def fetch_retrieval_image_urls(retrieval_results, sleep_sec=0.25, timeout=4.0):
+     pairs = []
+     for item in retrieval_results:
+         oid = item.get("original_id")
+         if oid is None:
+             continue
+         api_url = f"https://danbooru.donmai.us/posts/{oid}.json"
+         try:
+             resp = requests.get(api_url, timeout=timeout)
+             if resp.status_code != 200:
+                 continue  # skip posts we cannot fetch; finally still sleeps
+             data = resp.json()
+             url = data.get("large_file_url") or data.get("file_url") or data.get("preview_file_url")
+             if not url:
+                 continue
+
+             # Normalize protocol-relative and site-relative URLs.
+             if url.startswith("//"):
+                 url = "https:" + url
+             elif url.startswith("/"):
+                 url = "https://danbooru.donmai.us" + url
+             pairs.append((url, oid))
+         except Exception:
+             # Network or JSON errors: drop this result and move on.
+             pass
+         finally:
+             time.sleep(sleep_sec)
+     return pairs
+
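+ # Full inference pipeline for one image: flatten transparency onto a white
+ # background, embed with the ViT backbone, pool with the MAP head, then run
+ # the classifier heads and the 256-d retrieval head under no_grad.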
+ def process_image(image, k, distance_threshold_min, distance_threshold_max):
+     if image is None:
+         return "", {}, {}, {}, {}, {}, []
+     try:
+         image = image.convert('RGBA')
+         background = Image.new('RGBA', image.size, (255, 255, 255, 255))
+         image = Image.alpha_composite(background, image).convert('RGB')
+
+         image_inputs = processor(images=[image], return_tensors="pt").to(device).to(dtype)
+
+     except (OSError, IOError) as e:
+         print(f"Error opening image: {e}")
+         # Return empty values for all seven outputs so Gradio can render.
+         return "", {}, {}, {}, {}, {}, []
+     with torch.no_grad():
+         embedding = model(image_inputs.pixel_values)
+
+         embedding = model_map(embedding)
+
+         embedding_r = mlp_r(embedding)
+
+         retrieval_results = search_index(embedding_r, k, distance_threshold_min, distance_threshold_max)
+
+         url_id_pairs = fetch_retrieval_image_urls(retrieval_results)
+
+         retrieval_gallery_items = [(url, f"https://danbooru.donmai.us/posts/{oid}") for url, oid in url_id_pairs]
+
+         general_prediction = mlp_general(embedding)
+         general_ = prediction_to_tag(general_prediction, general_dict, general_class)
+         general_tags = general_[0]
+         rating = general_[4]
+
+         character_prediction = mlp_character(embedding)
+         character_ = prediction_to_tag(character_prediction, character_dict, character_class)
+         character_tags = character_[1]
+
+         artist_retrieval_prediction = mlp_artist_retrieval(embedding)
+         artist_retrieval_ = prediction_to_retrieval(artist_retrieval_prediction, artist_dict, artist_class, 10)
+         artist_tags = artist_retrieval_[0]
+         date = artist_retrieval_[1]
+
+         # Drop tags implied by a more specific tag already present, then
+         # convert underscores to spaces (except for kaomoji tags) and escape
+         # parentheses for prompt syntax.
+         tags_list = list(general_tags)
+         remove_list = []
+         for tag in tags_list:
+             if tag in implications_list:
+                 for implication in implications_list[tag]:
+                     remove_list.append(implication)
+         tags_list = [tag for tag in tags_list if tag not in remove_list]
+         tags_list = [tag.replace("_", " ") if tag not in kaomojis else tag for tag in tags_list]
+
+         tags_str = ", ".join(tags_list).replace("(", r"\(").replace(")", r"\)")
+
+         return (
+             tags_str,
+             artist_tags,
+             character_tags,
+             general_tags,
+             rating,
+             date,
+             retrieval_gallery_items,
+         )
+
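+ # Gradio UI: image and retrieval controls on the left; tag string,
+ # per-category labels, and a gallery of visually similar Danbooru posts on
+ # the right.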
+ def main():
+     with gr.Blocks(title=TITLE) as demo:
+         with gr.Column():
+             gr.Markdown(
+                 value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
+             )
+             with gr.Row():
+                 with gr.Column(variant="panel"):
+                     submit = gr.Button(value="Submit", variant="primary", size="lg")
+                     image = gr.Image(type="pil", image_mode="RGBA", label="Input")
+                     k_slider = gr.Slider(1, 100, value=32, step=1, label="Top K Results")
+                     distance_min_slider = gr.Slider(0, 128, value=0, step=1, label="Min Distance Threshold")
+                     distance_max_slider = gr.Slider(0, 128, value=80, step=1, label="Max Distance Threshold")
+                     with gr.Row():
+                         clear = gr.ClearButton(
+                             components=[
+                                 image,
+                                 k_slider,
+                                 distance_min_slider,
+                                 distance_max_slider,
+                             ],
+                             variant="secondary",
+                             size="lg",
+                         )
+                     gr.Markdown(value=DESCRIPTION)
+                 with gr.Column(variant="panel"):
+                     tags_str = gr.Textbox(label="Output", lines=4)
+                     with gr.Row():
+                         rating = gr.Label(label="Rating")
+                         date = gr.Label(label="Year")
+                     artist_tags = gr.Label(label="Artist")
+                     character_tags = gr.Label(label="Character")
+                     general_tags = gr.Label(label="General")
+                     with gr.Row():
+                         retrieval_gallery = gr.Gallery(
+                             label="Retrieval Preview",
+                             columns=5,
+                         )
+         clear.add(
+             [
+                 tags_str,
+                 artist_tags,
+                 general_tags,
+                 character_tags,
+                 rating,
+                 date,
+                 retrieval_gallery,
+             ]
+         )
+
+         submit.click(
+             process_image,
+             inputs=[image, k_slider, distance_min_slider, distance_max_slider],
+             outputs=[
+                 tags_str,
+                 artist_tags,
+                 character_tags,
+                 general_tags,
+                 rating,
+                 date,
+                 retrieval_gallery,
+             ],
+         )
+
+     demo.queue(max_size=10)
+     demo.launch()
+
+ if __name__ == "__main__":
+     main()
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch
+ transformers
+ Pillow
+ gradio
+ einops
+ timm
+ accelerate
+ faiss-cpu