Spaces:
Runtime error
Runtime error
File size: 7,187 Bytes
c6b36b9 da39995 79ec174 08d46f7 c6b36b9 f300457 dfda494 ac0be08 08d46f7 f300457 c6b36b9 f300457 79ec174 f300457 c6b36b9 f300457 a4d0555 c6b36b9 f300457 ac0be08 7a5e372 ac0be08 c6b36b9 f300457 a4d0555 c6b36b9 f300457 c6b36b9 a4d0555 f300457 c6b36b9 f300457 c6b36b9 f300457 c6b36b9 f300457 c6b36b9 f300457 c6b36b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 |
import gradio as gr
import json
import os
import torch
from torchvision import transforms
from PIL import Image
import folium
import base64
import glob
import warnings
from datasets import load_dataset
import io
from zipfile import ZipFile
# Suppress FutureWarning noise at import time (presumably library deprecation
# notices such as torch.load's weights_only warning -- confirm).
warnings.filterwarnings("ignore", category=FutureWarning)

# Load the sighting-photo dataset from Hugging Face; only the train split is used.
dataset = load_dataset("hangunwoo07/Naturing_Bird_Data")
dataset = dataset['train']  # Access the train split directly

# Encyclopedia entries, keyed by bird id.
with open('DB/bird_data.json', 'r', encoding='utf-8') as f:
    bird_data = json.load(f)

# The checkpoint appears to hold a whole module (the result is used directly via
# .eval() and called on batches), so full unpickling is required. Passing
# weights_only=False explicitly keeps this working on torch >= 2.6, where the
# default became True and would reject a pickled module.
# SECURITY: unpickling can execute arbitrary code -- only load a trusted file.
model_ft = torch.load(
    'bird_detection_model.pth',
    map_location=torch.device('cpu'),
    weights_only=False,
)
model_ft.eval()

# Class labels, looked up by the stringified predicted class index.
with open('DB/class_names.json', 'r', encoding='utf-8') as f:
    classes = json.load(f)

# Constants
# Map centre / school marker position as [latitude, longitude]
# (대전과학고등학교, Daejeon Science High School).
DSHS_LOCATION = [36.373719, 127.370415]
def create_image_popup(bird_name, image_filename):
    """Return an ``<img>`` tag that embeds one sighting photo as base64.

    The photo is located by building the key ``"<bird_name>/<image_filename>"``,
    finding it in the ``image`` column of the module-level ``metadata`` table,
    and reading the ``image`` field of the corresponding ``dataset`` row.

    Args:
        bird_name: Bird name; the first component of the lookup key.
        image_filename: Photo file name; surrounding whitespace is ignored.

    Returns:
        str: An HTML ``<img>`` tag, or ``''`` when the photo cannot be found or
        read. The caller still draws its marker, just without a picture.
    """
    # The original ``'/'.join(bird_name, image_filename.split())`` passed two
    # arguments to str.join and always raised TypeError.
    target_image = '/'.join([bird_name, image_filename.strip()])
    try:
        # NOTE(review): ``metadata`` is not defined anywhere in the visible
        # source; presumably a pandas DataFrame describing the dataset rows --
        # TODO define it. Until then this raises NameError and we return ''.
        # NOTE(review): the ``- 1`` offset between metadata and dataset rows is
        # kept from the original; confirm it is correct.
        row_index = metadata[metadata['image'] == target_image].index[0] - 1
        # Index the row first: on a Hugging Face Dataset, ``dataset['image']``
        # materialises the whole column just to read one entry.
        image = dataset[int(row_index)]['image']
        if hasattr(image, 'save'):
            # Presumably a decoded PIL image; re-encode it in memory.
            # JPEG cannot store alpha/palette modes, so convert those first.
            if getattr(image, 'mode', 'RGB') not in ('RGB', 'L'):
                image = image.convert('RGB')
            buffer = io.BytesIO()
            image.save(buffer, format='JPEG')
            raw = buffer.getvalue()
        else:
            # Otherwise treat the value as a path to an image file.
            with open(image, 'rb') as image_file:
                raw = image_file.read()
    except Exception as e:
        # Best effort: report and fall back to a popup without a picture.
        print(f"Could not load image {target_image}: {e}")
        return ''
    encoded = base64.b64encode(raw).decode()
    return f'<img src="data:image/jpeg;base64,{encoded}" width="200px">'
def _add_sighting_markers(fmap, bird_name, sightings):
    """Add one red circle marker per sighting record of ``bird_name`` to ``fmap``."""
    for location_info in sightings:
        try:
            latitude = location_info['latitude']
            longitude = location_info['longitude']
            location = location_info['location'].lstrip()  # Remove leading whitespace
            image_filename = location_info['image_filename']
            # Popup: bird name, place description, and (if available) the photo.
            popup_content = (
                "\n<div>\n"
                f"<h4>{bird_name}</h4>\n"
                f"<p>{location}</p>\n"
                f"{create_image_popup(bird_name, image_filename)}\n"
                "</div>\n"
            )
            folium.CircleMarker(
                location=[latitude, longitude],
                radius=4,
                popup=folium.Popup(popup_content, max_width=300),
                tooltip=bird_name,
                color='red',
                fill=True,
            ).add_to(fmap)
        except Exception as e:
            # Best effort: a malformed record is reported and skipped.
            print(f"Error processing location for {bird_name}: {e}")


def create_map():
    """Create an interactive map with bird sightings.

    The map is centred on ``DSHS_LOCATION`` and has a marker for the school. It
    also gets one red circle marker per record found in the files
    ``./DB/bird_locations_json/bird_locations_<bird name>.json``. Each file is
    read as JSON and is expected to contain a ``<bird name>`` key whose value is
    a list of records with ``latitude``, ``longitude``, ``location`` and
    ``image_filename`` fields.

    Errors in one file or record are printed and skipped, so one bad entry does
    not take down the whole map.

    Returns:
        str: The map rendered as embeddable HTML.
    """
    m = folium.Map(location=DSHS_LOCATION, zoom_start=15)

    # Marker for the school itself (대전과학고등학교 = Daejeon Science High School).
    dshs_popup_content = "\n<div>\n<h4>대전과학고등학교</h4>\n</div>\n"
    folium.Marker(
        DSHS_LOCATION,
        popup=folium.Popup(dshs_popup_content, max_width=300),
        tooltip="대전과학고등학교",
    ).add_to(m)

    prefix = "bird_locations_"
    # glob order depends on the filesystem; sorting makes the marker order deterministic.
    for file_path in sorted(glob.glob('./DB/bird_locations_json/bird_locations_*.json')):
        # The bird name is everything after the prefix in the file stem. The
        # original split the whole path on '_' and kept the last piece, which
        # truncated any bird name containing an underscore.
        stem = os.path.splitext(os.path.basename(file_path))[0]
        bird_name = stem[len(prefix):]
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                locations_data = json.load(f)
            if bird_name in locations_data and locations_data[bird_name]:
                _add_sighting_markers(m, bird_name, locations_data[bird_name])
            else:
                print(f'No locations found for {bird_name}')
        except Exception as e:
            print(f"Error processing file {file_path}: {e}")
    return m._repr_html_()
def _gallery_item(bird_id, bird_info):
    """Return the ``(image path, caption)`` pair shown for one bird in the gallery."""
    # NOTE(review): ``image_folder`` is referenced here (as in the original) but
    # is never assigned anywhere in the visible source, so this raises NameError
    # when the gallery is first built. TODO define it as the directory holding
    # the ``<bird_id>.jpg`` images.
    return os.path.join(image_folder, f"{bird_id}.jpg"), str(bird_info['common_name'])


def search_birds(search_term):
    """Return gallery items for birds whose common or scientific name contains ``search_term``.

    Matching is case-insensitive. An empty (or ``None``) term matches every bird.
    Items come back in ``bird_data`` order.
    """
    # Lowercase the term once instead of twice per bird.
    needle = (search_term or "").lower()
    return [
        _gallery_item(bird_id, bird_info)
        for bird_id, bird_info in bird_data.items()
        if needle in bird_info['common_name'].lower()
        or needle in bird_info['scientific_name'].lower()
    ]


def main_page():
    """Return gallery items for every bird in ``bird_data``, in dictionary order."""
    return [_gallery_item(bird_id, bird_info) for bird_id, bird_info in bird_data.items()]
def detail_page(evt: gr.SelectData):
    """Return the clicked gallery image and a Markdown profile of that bird.

    Args:
        evt: Gradio selection event from the gallery. ``evt.value['image']['path']``
            is the path of the clicked thumbnail.

    Returns:
        tuple: ``(image_path, markdown)`` for the image and info components.
    """
    clicked_path = evt.value['image']['path']
    # Gallery files are named "<bird_id>.jpg", so the part of the file name
    # before the first '.' is the bird_data key.
    bird = bird_data[os.path.basename(clicked_path).split('.')[0]]
    # Section headings: 분류 = taxonomy (문 phylum, 강 class, 목 order,
    # 과 family, 속 genus); 생태적 특징 = ecological traits;
    # 일반적 특징 = general traits.
    lines = [
        "",
        f"# {bird['common_name']} ({bird['scientific_name']})",
        "## 분류",
        f"- 문: {bird['classification']['phylum']}",
        f"- 강: {bird['classification']['class']}",
        f"- 목: {bird['classification']['order']}",
        f"- 과: {bird['classification']['family']}",
        f"- 속: {bird['classification']['genus']}",
        "## 생태적 특징",
        f"{bird['ecological_characteristics']}",
        "## 일반적 특징",
        f"{bird['general_characteristics']}",
        "",
    ]
    return clicked_path, "\n".join(lines)
def apply_test_transforms(inp):
    """Turn an input image into a normalised 224x224 tensor for the classifier.

    Steps:
        1. Resize to 224x224; the aspect ratio is not preserved.
        2. Convert to a tensor.
        3. Normalise each channel with the widely used ImageNet mean/std.
    """
    tf = transforms.functional
    target_size = [224, 224]
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    resized = tf.resize(inp, target_size)
    return tf.normalize(tf.to_tensor(resized), channel_mean, channel_std)
def predict(model, filepath):
    """Classify the bird in the image at ``filepath`` and return its class label.

    Args:
        model: A torch module returning per-class scores along dim 1 for a
            batch of normalised 224x224 RGB images.
        filepath: Path to an image file readable by PIL.

    Returns:
        The entry of ``classes`` keyed by the stringified arg-max class index.
    """
    # convert('RGB') drops alpha/palette channels so PNG/GIF uploads yield the
    # 3 channels that the 3-element normalize() mean/std expect. The with-block
    # closes the underlying file.
    with Image.open(filepath) as im:
        rgb = im.convert("RGB")
    minibatch = torch.stack([apply_test_transforms(rgb)])
    # Run on whatever device the model's weights live on. The original moved the
    # batch to CUDA whenever it was available, even though the model is loaded
    # with map_location=cpu, which fails with a device mismatch on GPU hosts.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        minibatch = minibatch.to(first_param.device)
    with torch.no_grad():  # inference only: do not build an autograd graph
        pred = model(minibatch)
    _, classnum = torch.max(pred, 1)
    print(classnum)
    return classes[str(classnum.item())]
def classify_bird(image):
    """Gradio click handler for the identification tab: name the bird in a photo.

    Args:
        image: Path to the uploaded photo (the input component is created with
            ``type="filepath"``), or ``None`` when nothing has been uploaded.

    Returns:
        The predicted class label from ``predict``.

    Raises:
        gr.Error: If no photo was uploaded. The UI then shows a readable message
            instead of an opaque failure from ``Image.open(None)``.
    """
    if image is None:
        # "Please upload a bird photo first."
        raise gr.Error("먼저 새 사진을 업로드해 주세요.")
    return predict(model_ft, image)
# Build the Gradio UI. The map HTML and the initial gallery are computed once,
# while the layout is being constructed (create_map() / main_page() are called here).
with gr.Blocks() as demo:
    gr.Markdown("# BIORD")
    gr.Markdown("## Bird's Information & Organized Regional Database")

    # Daejeon Science High School (DSHS) map tab
    with gr.Tab("대전과고 지도"):
        map_html = gr.HTML(value=create_map())

    # Bird encyclopedia tab: searchable gallery plus a detail pane
    with gr.Tab("조류 도감"):
        with gr.Row():
            search_input = gr.Textbox(label="새 이름 검색", placeholder="검색하고 싶은 새의 이름을 입력하세요")
        with gr.Row():
            with gr.Column(scale=2):
                # NOTE(review): columns=40 renders up to 40 thumbnails per row; confirm intended.
                gallery = gr.Gallery(value=main_page(), columns=40, rows=6, height=660)
            with gr.Column(scale=3):
                selected_image = gr.Image(label="선택한 새")
                info = gr.Markdown(label="상세 정보")
        # Re-filter the gallery on every keystroke; show details for the clicked bird.
        search_input.change(search_birds, inputs=[search_input], outputs=[gallery])
        gallery.select(detail_page, None, [selected_image, info])

    # Bird identification tab: upload a photo and run the classifier
    with gr.Tab("조류 동정"):
        image_input = gr.Image(type="filepath")
        classify_btn = gr.Button("예측하기")
        output = gr.Textbox(label="예측 결과")
        classify_btn.click(fn=classify_bird, inputs=image_input, outputs=output)

# Launch the application. The original line ended with a stray "|"
# (web-scrape residue), which made it a SyntaxError.
demo.launch()