import gradio as gr
import json
import os
import torch
from torchvision import transforms
from PIL import Image
import folium
import base64
import glob
import warnings
from datasets import load_dataset
import io
from zipfile import ZipFile
# Presumably silences torch.load's FutureWarning about its `weights_only`
# default (this filter hides every FutureWarning, not just that one).
warnings.filterwarnings("ignore", category=FutureWarning)
# Load dataset from Hugging Face
dataset = load_dataset("hangunwoo07/Naturing_Bird_Data")
dataset = dataset['train']  # Access the train split directly
# Load bird data (encyclopedia entries keyed by bird id)
with open('DB/bird_data.json', 'r', encoding='utf-8') as f:
    bird_data = json.load(f)
# Load model and classes.
# The checkpoint is a whole pickled nn.Module (it is used directly via
# .eval() and model(batch)), not a state_dict, so it must be loaded with
# weights_only=False: since PyTorch 2.6 the default is True and would refuse
# to unpickle it. Unpickling can execute arbitrary code -- only ever load a
# trusted, locally produced checkpoint here.
model_ft = torch.load(
    'bird_detection_model.pth',
    map_location=torch.device('cpu'),
    weights_only=False,
)
model_ft.eval()
# Maps str(class index) -> class name (see predict()).
with open('DB/class_names.json', 'r', encoding='utf-8') as f:
    classes = json.load(f)
# NOTE(review): `metadata` (used by create_image_popup) and `image_folder`
# (used by search_birds / main_page) are never defined in this file; they must
# be created at module level before those functions run.
# Constants
DSHS_LOCATION = [36.373719, 127.370415]  # [lat, lon] of Daejeon Science High School
def _image_bytes(image):
    """Return the raw bytes of a dataset image (file path, image dict, or PIL image)."""
    if isinstance(image, dict):
        # Presumably an undecoded `datasets` Image value: {'bytes': ..., 'path': ...}.
        if image.get('bytes'):
            return image['bytes']
        image = image['path']
    if isinstance(image, (str, bytes, os.PathLike)):
        with open(image, 'rb') as image_file:
            return image_file.read()
    # Otherwise assume a PIL image (what a decoded Image column yields);
    # re-encode it as JPEG to match the data-URI MIME type used by the caller.
    buffer = io.BytesIO()
    image.convert('RGB').save(buffer, format='JPEG')
    return buffer.getvalue()


def create_image_popup(bird_name, image_filename):
    """Return an HTML ``<img>`` tag embedding one sighting photo as base64.

    The photo is found by matching ``"<bird_name>/<image_filename>"`` against
    the ``image`` column of the module-level ``metadata`` table and reading the
    corresponding row of the module-level Hugging Face ``dataset``.

    This is best-effort: any failure (missing row, unreadable image, missing
    lookup table, bad filename) yields ``''`` so the map marker is still drawn,
    just without a picture.

    Args:
        bird_name: Bird name taken from the location file name.
        image_filename: ``image_filename`` field of a sighting record.

    Returns:
        An ``<img>`` HTML snippet, or ``''`` if the image cannot be loaded.
    """
    try:
        # The original '/'.join(bird_name, image_filename.split()) raised
        # TypeError (str.join takes a single iterable), which dropped every
        # map marker. NOTE(review): assumes metadata['image'] stores
        # "<bird_name>/<filename>" -- confirm against the metadata source.
        target_image = f"{bird_name}/{image_filename.strip()}"
        # NOTE(review): `metadata` is not defined in this file; until it is,
        # this raises NameError and the popup falls back to ''.
        # The "- 1" offset is kept from the original; confirm it matches how
        # `metadata` rows line up with `dataset` rows.
        row_index = metadata[metadata['image'] == target_image].index[0] - 1
        # Index the row first: dataset['image'] would materialize (and decode)
        # the entire column on every call.
        image = dataset[int(row_index)]['image']
        encoded = base64.b64encode(_image_bytes(image)).decode()
    except Exception as err:
        print(f"Could not load image for {bird_name}/{image_filename}: {err}")
        return ''
    return f'<img src="data:image/jpeg;base64,{encoded}" width="250">'
def _bird_name_from_path(file_path):
    """Extract '<bird>' from a '.../bird_locations_<bird>.json' path."""
    stem = os.path.splitext(os.path.basename(file_path))[0]
    prefix = 'bird_locations_'
    # Strip only the known prefix so names containing '_' stay intact.
    return stem[len(prefix):] if stem.startswith(prefix) else stem


def _add_bird_markers(m, file_path):
    """Add one red CircleMarker to *m* per sighting listed in *file_path*."""
    bird_name = _bird_name_from_path(file_path)
    with open(file_path, 'r', encoding='utf-8') as f:
        locations_data = json.load(f)
    # The file maps the bird name to its list of sightings.
    sightings = locations_data.get(bird_name)
    if not sightings:
        print(f'No locations found for {bird_name}')
        return
    for location_info in sightings:
        try:
            latitude = location_info['latitude']
            longitude = location_info['longitude']
            location = location_info['location'].lstrip()  # Remove leading whitespace
            image_filename = location_info['image_filename']
            # Create popup content with image from dataset
            popup_content = f"""
{bird_name}
{location}
{create_image_popup(bird_name, image_filename)}
"""
            folium.CircleMarker(
                location=[latitude, longitude],
                radius=4,
                popup=folium.Popup(popup_content, max_width=300),
                tooltip=bird_name,
                color='red',
                fill=True
            ).add_to(m)
        except Exception as e:
            # A malformed sighting skips only that marker.
            print(f"Error processing location for {bird_name}: {e}")


def create_map():
    """Build the interactive folium map of DSHS and all recorded bird sightings.

    A marker is placed on the school, then every
    ``./DB/bird_locations_json/bird_locations_<bird>.json`` file is read and
    each sighting in it becomes a red circle marker whose popup shows the bird
    name, the location text and (when available) a photo.

    Returns:
        The map rendered as an HTML string, suitable for ``gr.HTML``.
    """
    m = folium.Map(location=DSHS_LOCATION, zoom_start=15)
    dshs_popup_content = f"""
대전과학고등학교
"""
    # Marker for Daejeon Science High School
    folium.Marker(
        DSHS_LOCATION,
        popup=folium.Popup(dshs_popup_content, max_width=300),
        tooltip="대전과학고등학교"
    ).add_to(m)
    # sorted() makes marker order independent of filesystem listing order.
    location_files = sorted(glob.glob('./DB/bird_locations_json/bird_locations_*.json'))
    for file_path in location_files:
        try:
            _add_bird_markers(m, file_path)
        except Exception as e:
            # One unreadable file must not stop the whole map from rendering.
            print(f"Error processing file {file_path}: {e}")
    return m._repr_html_()
def search_birds(search_term):
    """Return gallery items whose common or scientific name contains *search_term*.

    Matching is case-insensitive; an empty (or ``None``) term matches every bird.

    Args:
        search_term: Text from the search box.

    Returns:
        List of ``(image_path, common_name)`` tuples for ``gr.Gallery``.
    """
    # Lower-case the query once rather than twice per bird; treat None as "".
    needle = (search_term or "").lower()
    # NOTE(review): `image_folder` is not defined in this file; it must exist
    # at module level before any bird matches.
    return [
        (os.path.join(image_folder, f"{bird_id}.jpg"), f"{bird_info['common_name']}")
        for bird_id, bird_info in bird_data.items()
        if needle in bird_info['common_name'].lower()
        or needle in bird_info['scientific_name'].lower()
    ]
def main_page():
    """Build the initial gallery: every bird in ``bird_data``, in dict order.

    Returns:
        List of ``(image_path, common_name)`` tuples for ``gr.Gallery``.
    """
    # NOTE(review): relies on a module-level `image_folder`, which this file
    # does not define.
    return [
        (os.path.join(image_folder, f"{key}.jpg"), f"{entry['common_name']}")
        for key, entry in bird_data.items()
    ]
def detail_page(evt: gr.SelectData):
    """Show the clicked gallery bird's image and a Markdown fact sheet.

    Args:
        evt: Gradio selection event (the ``gr.SelectData`` annotation is what
            makes Gradio inject it); ``evt.value['image']['path']`` is read as
            the path of the clicked gallery image.

    Returns:
        ``(image_path, markdown_text)`` for the image and Markdown outputs.
    """
    image_path = evt.value['image']['path']
    # Gallery files are named "<bird_id>.jpg" (see main_page/search_birds),
    # presumably preserved by Gradio's cache. splitext drops only the final
    # extension, so ids containing dots are not truncated like split('.')[0].
    bird_id = os.path.splitext(os.path.basename(image_path))[0]
    selected_bird = bird_data[bird_id]
    info = f"""
# {selected_bird['common_name']} ({selected_bird['scientific_name']})
## 분류
- 문: {selected_bird['classification']['phylum']}
- 강: {selected_bird['classification']['class']}
- 목: {selected_bird['classification']['order']}
- 과: {selected_bird['classification']['family']}
- 속: {selected_bird['classification']['genus']}
## 생태적 특징
{selected_bird['ecological_characteristics']}
## 일반적 특징
{selected_bird['general_characteristics']}
"""
    return image_path, info
def apply_test_transforms(inp):
    """Preprocess a PIL image for the classifier.

    Resizes to 224x224, converts to a float tensor, and normalizes each
    channel with the commonly used ImageNet mean/std (presumably matching the
    training pipeline).
    """
    tf = transforms.functional
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    resized = tf.resize(inp, [224, 224])
    as_tensor = tf.to_tensor(resized)
    return tf.normalize(as_tensor, channel_mean, channel_std)
def predict(model, filepath):
    """Classify the image at *filepath* and return its class name.

    Args:
        model: Classifier module; its output is reduced with an argmax over
            dim 1, so presumably shape (batch, num_classes).
        filepath: Path to the image file.

    Returns:
        ``classes[str(index)]`` for the highest-scoring class index.
    """
    # Close the file promptly, and force 3 channels: RGBA/grayscale inputs
    # would otherwise fail normalization with 3-element mean/std.
    with Image.open(filepath) as im:
        im_as_tensor = apply_test_transforms(im.convert('RGB'))
    minibatch = im_as_tensor.unsqueeze(0)  # add batch dimension
    # Put the batch on the model's own device. The model is loaded on CPU, so
    # the previous unconditional .cuda() caused a device mismatch on GPU hosts.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    minibatch = minibatch.to(device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        pred = model(minibatch)
    _, classnum = torch.max(pred, 1)
    return classes[str(classnum.item())]
def classify_bird(image):
    """Gradio handler for the "예측하기" (predict) button.

    Args:
        image: File path from ``gr.Image(type="filepath")``; presumably
            ``None`` when nothing has been uploaded.

    Returns:
        The predicted class name, or a prompt asking for an image.
    """
    if not image:
        # Without a file, Image.open() inside predict() would raise.
        return "이미지를 먼저 업로드해 주세요."
    return predict(model_ft, image)
# Build the Gradio UI: a map tab, an encyclopedia tab and an identification tab.
with gr.Blocks() as demo:
    gr.Markdown("# BIORD")
    gr.Markdown("## Bird's Information & Organized Regional Database")
    # Tab: map of Daejeon Science High School and bird sightings.
    # create_map() runs once, while the UI is being built.
    with gr.Tab("대전과고 지도"):
        map_html = gr.HTML(value=create_map())
    # Tab: bird encyclopedia -- searchable gallery plus a detail view.
    with gr.Tab("조류 도감"):
        with gr.Row():
            search_input = gr.Textbox(label="새 이름 검색", placeholder="검색하고 싶은 새의 이름을 입력하세요")
        with gr.Row():
            with gr.Column(scale=2):
                # NOTE(review): columns=40 is unusually high for a gallery;
                # confirm it is not meant to be 4.
                gallery = gr.Gallery(value=main_page(), columns=40, rows=6, height=660)
            with gr.Column(scale=3):
                selected_image = gr.Image(label="선택한 새")
                info = gr.Markdown(label="상세 정보")
        # Re-filter the gallery whenever the search text changes.
        search_input.change(search_birds, inputs=[search_input], outputs=[gallery])
        # Clicking a gallery item fills the image and Markdown panes.
        gallery.select(detail_page, None, [selected_image, info])
    # Tab: bird identification using the image classifier.
    with gr.Tab("조류 동정"):
        image_input = gr.Image(type="filepath")
        classify_btn = gr.Button("예측하기")
        output = gr.Textbox(label="예측 결과")
        classify_btn.click(fn=classify_bird, inputs=image_input, outputs=output)
# Launch the application.
demo.launch()