Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| import pandas as pd | |
| import gradio as gr | |
| from io import BytesIO | |
| from PIL import Image as PILIMAGE | |
| #from IPython.display import Image | |
| #from IPython.core.display import HTML | |
| from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer | |
| import os | |
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Fine-tuned CLIP weights. Image preprocessing and tokenisation come from the
# base OpenAI checkpoint (presumably the one the fine-tune started from - TODO confirm).
_BASE_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained("vesteinn/clip-nabirds")
model = model.to(device)
processor = CLIPProcessor.from_pretrained(_BASE_CLIP_CHECKPOINT)
tokenizer = CLIPTokenizer.from_pretrained(_BASE_CLIP_CHECKPOINT)
def load_class_names(dataset_path=''):
    """Read ``classes.txt`` into a mapping of class id -> class name.

    Each non-blank line is expected to be ``<class_id> <name words...>``; the
    name words are re-joined with single spaces. Blank lines are skipped.

    Args:
        dataset_path: directory containing ``classes.txt`` (default: current
            working directory).

    Returns:
        dict[str, str]: class id (kept as a string) -> class name, in file order.
    """
    names = {}
    path = os.path.join(dataset_path, 'classes.txt')
    with open(path, encoding='utf-8') as f:
        for line in f:
            pieces = line.split()
            if not pieces:
                # A blank line (e.g. a trailing newline at EOF) previously
                # raised IndexError on pieces[0].
                continue
            class_id, *name_words = pieces
            names[class_id] = ' '.join(name_words)
    return names
def get_labels():
    """Build one CLIP text prompt per class listed in ``./classes.txt``.

    Returns:
        list[str]: prompts of the form "This is a photo of <name>.", in the
        same order the classes appear in the file.
    """
    return [f"This is a photo of {name}." for name in load_class_names(".").values()]
def encode_text(text):
    """Embed a single text prompt with the CLIP text tower.

    The returned features are NOT L2-normalised.

    Args:
        text: the prompt string to encode.

    Returns:
        numpy array of text features for a batch of one prompt (a single row),
        always on the host.
    """
    with torch.no_grad():
        inputs = tokenizer([text], padding=True, return_tensors="pt")
        # The tokenizer output is on the CPU; it must live on the same device as
        # the model, or the forward pass fails when running on CUDA.
        inputs = inputs.to(device)
        text_encoded = model.get_text_features(**inputs)
    # .cpu() is required before .numpy() when the model runs on a GPU.
    return text_encoded.cpu().numpy()
# One text prompt per class; the index into this list is the class index used
# by LABEL_FEATURES.
ALL_LABELS = get_labels()

# np.save appends ".npy" when given a *path*, so the cache is written through a
# file handle to keep the exact name np.load reads back.
_LABEL_FEATURES_PATH = "label_features.np"

try:
    LABEL_FEATURES = np.load(_LABEL_FEATURES_PATH)
except (OSError, ValueError, EOFError):
    # Missing, truncated, or unreadable cache: rebuild it below.
    LABEL_FEATURES = None

# Also rebuild when the cache no longer matches the current label list
# (e.g. classes.txt changed), which would otherwise silently misalign indices.
if (
    LABEL_FEATURES is None
    or LABEL_FEATURES.ndim != 2
    or LABEL_FEATURES.shape[0] != len(ALL_LABELS)
):
    LABEL_FEATURES = np.vstack([encode_text(label) for label in ALL_LABELS])
    with open(_LABEL_FEATURES_PATH, "wb") as cache_file:
        np.save(cache_file, LABEL_FEATURES)
def encode_image(image):
    """Embed an image with the CLIP vision tower and L2-normalise the result.

    Args:
        image: numpy array of pixel values as delivered by the Gradio image
            input (presumably H x W x 3; grayscale or RGBA arrays are also
            accepted and converted to RGB).

    Returns:
        numpy array holding one unit-norm image feature row, on the host.
    """
    # Let PIL infer the mode from the array's shape, then convert. Forcing
    # mode 'RGB' misinterprets 2-D (grayscale) or 4-channel (RGBA) arrays.
    pil_image = PILIMAGE.fromarray(image.astype('uint8')).convert('RGB')
    with torch.no_grad():
        pixel_values = processor(
            text=None, images=pil_image, return_tensors="pt", padding=True
        )["pixel_values"]
        features = model.get_image_features(pixel_values.to(device))
        # Unit-normalise so dot products with label features are cosine-like.
        features = features / features.norm(dim=-1, keepdim=True)
    return features.cpu().numpy()
def similarity(feature, label_features):
    """Cosine similarity between one query feature and every label feature.

    Both inputs are L2-normalised row-wise before the dot product. The result is
    therefore a true cosine similarity even when the label rows (e.g. raw,
    unnormalised CLIP text embeddings) have different norms. Otherwise, labels
    with larger embedding norms would be unfairly favoured.

    Args:
        feature: query embedding, shape (1, D) or (D,).
        label_features: label embeddings, shape (K, D).

    Returns:
        list of K numpy scalars, one similarity per label, in label order.
        A zero-norm row yields similarity 0 instead of NaN.
    """

    def _unit_rows(x):
        # Normalise each row; leave all-zero rows as zeros (avoid 0/0).
        norms = np.linalg.norm(x, axis=-1, keepdims=True)
        return x / np.where(norms == 0, 1, norms)

    query = _unit_rows(np.atleast_2d(np.asarray(feature)))
    labels = _unit_rows(np.asarray(label_features))
    return list((query @ labels.T).squeeze(0))
def find_best_matches(image):
    """Return the label prompt that best matches ``image``.

    Args:
        image: numpy pixel array from the Gradio image input.

    Returns:
        str: the entry of ALL_LABELS with the highest similarity to the image.

    Raises:
        ValueError: if no image was provided.
    """
    if image is None:
        raise ValueError("No image provided; please upload an image to classify.")
    image_features = encode_image(image)
    similarities = similarity(image_features, LABEL_FEATURES)
    # np.argmax returns the first index among ties. That matches the stable
    # descending sort it replaces, without sorting every label.
    best_idx = int(np.argmax(similarities))
    return ALL_LABELS[best_idx]
# Sample images shown under the input widget (one input per example row).
examples = [
    ['bj.jpg'],
    ['duckly.jpg'],
    ['some.jpg'],
    ['turdus.jpg'],
    ['seag.jpg'],
    ['thursh.jpg'],
    ['woodcock.jpeg'],
    ['dipper.jpeg'],
]

# gr.inputs / gr.outputs, the `optional` flag, string themes such as "grass",
# and Interface(enable_queue=...) are legacy Gradio APIs (deprecated in 3.x,
# removed in 4.x). Use the top-level components and .queue() instead.
# type="numpy" is explicit because find_best_matches expects a numpy array.
demo = gr.Interface(
    fn=find_best_matches,
    inputs=[gr.Image(label="Image to classify", type="numpy")],
    outputs=gr.Label(),
    examples=examples,
    title="North American Bird Classifier",
    description="This application can classify North American Birds.",
)
demo.queue().launch()