|
|
"""Gradio web application for snow leopard identification and catalog exploration. |
|
|
|
|
|
This interactive web interface provides an easy-to-use frontend for the snow |
|
|
leopard identification system. Users can upload images, view matches against the catalog, |
|
|
and explore reference leopards through a browser-based UI powered by Gradio. |
|
|
|
|
|
Features: |
|
|
- Upload snow leopard images or select from examples |
|
|
- Run full identification pipeline with GDINO+SAM segmentation |
|
|
- View top-K matches with Wasserstein distance scores |
|
|
- Explore complete leopard catalog with thumbnails |
|
|
- Visualize matched keypoints between query and catalog images |
|
|
|
|
|
Usage: |
|
|
# Local testing with uv: |
|
|
uv sync |
|
|
uv run python app.py |
|
|
|
|
|
# Deployed on Hugging Face Spaces |
|
|
""" |
|
|
|
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
# Directory containing this app file; data and the bundled package live beneath it.
SPACE_ROOT = Path(__file__).parent
# Put the local ``src`` directory first on sys.path so the ``snowleopard_reid``
# imports below resolve to the copy shipped next to this file.
sys.path.insert(0, str(SPACE_ROOT / "src"))
|
|
|
|
|
import logging |
|
|
import shutil |
|
|
import tempfile |
|
|
from dataclasses import dataclass |
|
|
|
|
|
import cv2 |
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import torch |
|
|
import yaml |
|
|
from huggingface_hub import hf_hub_download |
|
|
from PIL import Image |
|
|
|
|
|
from snowleopard_reid.cache import ( |
|
|
filter_cached_matches, |
|
|
generate_visualizations_from_npz, |
|
|
is_cached, |
|
|
load_cached_results, |
|
|
) |
|
|
from snowleopard_reid.catalog import ( |
|
|
get_available_body_parts, |
|
|
get_available_locations, |
|
|
get_catalog_metadata_for_id, |
|
|
load_catalog_index, |
|
|
load_leopard_metadata, |
|
|
) |
|
|
from snowleopard_reid.data_setup import ensure_data_extracted |
|
|
from snowleopard_reid.pipeline.stages import ( |
|
|
run_feature_extraction_stage, |
|
|
run_matching_stage, |
|
|
run_preprocess_stage, |
|
|
run_segmentation_stage, |
|
|
select_best_mask, |
|
|
) |
|
|
from snowleopard_reid.pipeline.stages.segmentation import ( |
|
|
load_gdino_model, |
|
|
load_sam_predictor, |
|
|
) |
|
|
from snowleopard_reid.visualization import ( |
|
|
draw_keypoints_overlay, |
|
|
draw_matched_keypoints, |
|
|
draw_side_by_side_comparison, |
|
|
) |
|
|
|
|
|
|
|
|
# Root logging configuration for the app: INFO level, timestamped records.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
# Module logger used by all functions in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
# Data locations, all resolved relative to SPACE_ROOT.
CATALOG_ROOT = SPACE_ROOT / "data" / "catalog"
SAM_CHECKPOINT_DIR = SPACE_ROOT / "data" / "models"
# SAM HQ checkpoint file name; ensure_sam_model() downloads it from the
# "lkeab/hq-sam" Hugging Face repo into SAM_CHECKPOINT_DIR when missing.
SAM_CHECKPOINT_NAME = "sam_hq_vit_l.pth"
EXAMPLES_DIR = SPACE_ROOT / "data" / "examples"
# Hugging Face model id of the Grounding DINO detector.
GDINO_MODEL_ID = "IDEA-Research/grounding-dino-base"
# Default detection prompt for Grounding DINO.
TEXT_PROMPT = "a snow leopard."
# Default number of matches to display (presumably used when building AppConfig — confirm).
TOP_K_DEFAULT = 5
# SAM model type passed to load_sam_predictor; keep consistent with the checkpoint above.
SAM_MODEL_TYPE = "vit_l"
|
|
|
|
|
|
|
|
@dataclass
class AppConfig:
    """Configuration for the Snow Leopard ID UI application.

    Passed to initialize_models, load_catalog_data, run_identification and the
    tab/app builders in this file.
    """

    # NOTE(review): not read by any function visible in this file section; confirm whether still used.
    model_path: Path | None
    # Catalog root: holds catalog_index.yaml and the database/<location>/<leopard>/ tree.
    catalog_root: Path
    # Directory of example query images (globbed for *.jpg when building the app).
    examples_dir: Path
    # Default number of matches to show (presumably the UI's initial value — confirm).
    top_k: int
    # Presumably forwarded to Gradio's launch() — not visible in this section.
    port: int
    share: bool

    # Segmentation model settings consumed by initialize_models().
    sam_checkpoint_path: Path
    sam_model_type: str
    gdino_model_id: str
    # Grounding DINO text prompt; stored in LOADED_MODELS["text_prompt"] for each run.
    text_prompt: str
|
|
|
|
|
|
|
|
def ensure_sam_model() -> Path:
    """Return the local SAM HQ checkpoint path, downloading the file first if absent.

    The checkpoint ``SAM_CHECKPOINT_NAME`` is fetched from the ``lkeab/hq-sam``
    Hugging Face repository into ``SAM_CHECKPOINT_DIR`` (created if needed).

    Returns:
        Path to the SAM HQ checkpoint file.
    """
    checkpoint = SAM_CHECKPOINT_DIR / SAM_CHECKPOINT_NAME
    if checkpoint.exists():
        return checkpoint

    logger.info("Downloading SAM HQ model (1.6GB)...")
    SAM_CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    hf_hub_download(
        repo_id="lkeab/hq-sam",
        filename=SAM_CHECKPOINT_NAME,
        local_dir=SAM_CHECKPOINT_DIR,
    )
    logger.info("SAM HQ model downloaded successfully")
    return checkpoint
|
|
|
|
|
|
|
|
def get_available_extractors(catalog_root: Path) -> list[str]:
    """List the feature extractors the catalog has precomputed features for.

    Falls back to ``["sift"]`` when the index lists none or cannot be loaded,
    so the UI always has at least one choice.

    Args:
        catalog_root: Root directory of the leopard catalog.

    Returns:
        Extractor names from the index's ``feature_extractors`` mapping
        (e.g. ``['sift', 'superpoint']``), or ``["sift"]`` as a fallback.
    """
    try:
        index = load_catalog_index(catalog_root)
        names = list(index.get("feature_extractors", {}).keys())
    except Exception as e:
        logger.error(f"Failed to load catalog index: {e}")
        return ["sift"]

    if names:
        return names

    logger.warning(f"No extractors found in catalog at {catalog_root}")
    return ["sift"]
|
|
|
|
|
|
|
|
|
|
|
# Process-wide mutable state shared by the Gradio callbacks in this file.
# - Set once by initialize_models(): "gdino_processor", "gdino_model",
#   "sam_predictor", "device", "catalog_root", "text_prompt".
# - Replaced on every identification (run_identification / _load_from_cache) and
#   read by load_match_details_for_rank(): "current_match_visualizations",
#   "current_clean_comparison_visualizations", "current_enriched_matches",
#   "current_filter_body_parts", "current_temp_dir".
# NOTE(review): this single global is shared by all browser sessions, so concurrent
# users overwrite each other's "current_*" results; consider per-session gr.State.
LOADED_MODELS = {}
|
|
|
|
|
|
|
|
def load_catalog_data(config: AppConfig):
    """Load the catalog index and the metadata YAML of every individual it lists.

    Args:
        config: Application configuration; ``config.catalog_root`` must contain
            ``catalog_index.yaml``, whose ``individuals`` entries each carry a
            ``metadata_path`` relative to the catalog root.

    Returns:
        Tuple ``(catalog_index, individuals_data)``: the parsed index and a list
        with one parsed metadata mapping per index entry, in index order.

    Raises:
        OSError: If the index or a metadata file cannot be opened.
        KeyError: If the index has no ``individuals`` key or an entry lacks
            ``metadata_path``.
    """
    catalog_root = config.catalog_root

    # Read as UTF-8 explicitly: open() otherwise uses the locale's preferred
    # encoding, which breaks non-ASCII names on e.g. Windows (cp1252).
    with open(catalog_root / "catalog_index.yaml", encoding="utf-8") as f:
        catalog_index = yaml.safe_load(f)

    individuals_data = []
    for individual in catalog_index["individuals"]:
        with open(catalog_root / individual["metadata_path"], encoding="utf-8") as f:
            individuals_data.append(yaml.safe_load(f))

    return catalog_index, individuals_data
|
|
|
|
|
|
|
|
def initialize_models(config: AppConfig):
    """Load the segmentation models once and publish them via LOADED_MODELS.

    Stores the Grounding DINO processor/model, the SAM HQ predictor, the chosen
    device, the catalog root and the text prompt so each request can reuse them
    instead of reloading.

    Args:
        config: Application configuration providing model ids/paths, the
            catalog root and the detection prompt.
    """
    logger.info("Initializing models...")

    # Prefer the first CUDA GPU when one is present.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")
    if device == "cuda":
        name = torch.cuda.get_device_name(0)
        total_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        logger.info(f"GPU: {name} ({total_gb:.1f} GB)")

    logger.info(f"Loading Grounding DINO model: {config.gdino_model_id}")
    processor, detector = load_gdino_model(
        model_id=config.gdino_model_id,
        device=device,
    )
    LOADED_MODELS.update(gdino_processor=processor, gdino_model=detector)
    logger.info("Grounding DINO model loaded successfully")

    logger.info(
        f"Loading SAM HQ model from {config.sam_checkpoint_path} (type: {config.sam_model_type})"
    )
    LOADED_MODELS["sam_predictor"] = load_sam_predictor(
        checkpoint_path=config.sam_checkpoint_path,
        model_type=config.sam_model_type,
        device=device,
    )
    logger.info("SAM HQ model loaded successfully")

    LOADED_MODELS.update(
        device=device,
        catalog_root=config.catalog_root,
        text_prompt=config.text_prompt,
    )
    logger.info("Models initialized successfully")
|
|
|
|
|
|
|
|
def _load_from_cache(
    example_path: str,
    extractor: str,
    config: "AppConfig",
    filter_locations: list[str] | None = None,
    filter_body_parts: list[str] | None = None,
    top_k: int = 5,
):
    """Build identification outputs for an example image from its cached results.

    v2.0 caches keep every catalog match together with its location/body_part
    metadata, so filters are applied here rather than by re-running the
    pipeline. Caches that only carry a ``matches`` list are used as-is.

    Args:
        example_path: Path of the example image (cache key).
        extractor: Feature extractor name (cache key).
        config: Application configuration (used for catalog metadata lookups).
        filter_locations: Locations to keep, or None for all.
        filter_body_parts: Body parts to keep, or None for all.
        top_k: Number of matches to keep after filtering.

    Returns:
        Tuple of 23 UI components in the same layout as run_identification's
        output. When at least one match survives filtering, the ``current_*``
        entries of LOADED_MODELS are replaced as a side effect.
    """

    def indicator_for(score):
        # Same score bands as the live pipeline: >=0.12 blue, >=0.07 green, >=0.04 yellow, else red.
        for floor, symbol in ((0.12, "🔵"), (0.07, "🟢"), (0.04, "🟡")):
            if score >= floor:
                return symbol
        return "🔴"

    def location_for(match):
        # When the cached match has no usable location, derive it from the
        # catalog image path: the directory right after "database".
        location = match.get("location", "unknown")
        if location != "unknown":
            return location
        metadata = get_catalog_metadata_for_id(config.catalog_root, match["catalog_id"])
        if not metadata:
            return location
        parts = Path(metadata["image_path"]).parts
        if "database" in parts:
            following = parts.index("database") + 1
            if following < len(parts):
                return parts[following]
        return location

    cached = load_cached_results(example_path, extractor)
    predictions = cached["predictions"]

    # Prefer the full v2.0 match list; fall back to the legacy "matches" key.
    candidates = (
        predictions["all_matches"]
        if "all_matches" in predictions
        else predictions.get("matches", [])
    )
    matches = filter_cached_matches(
        all_matches=candidates,
        filter_locations=filter_locations,
        filter_body_parts=filter_body_parts,
        top_k=top_k,
    )

    if not matches:
        # Keep the cached pipeline images but clear every match-detail component.
        return (
            "No matches found with the selected filters",
            cached["segmentation_image"],
            cached["cropped_image"],
            cached["keypoints_image"],
            [],
            gr.update(value=None),
            gr.update(value=None),
            *[gr.update(value="") for _ in range(6)],
            *[gr.update(visible=False) for _ in range(5)],
            *[gr.update(value=[]) for _ in range(5)],
        )

    # Visualizations are regenerated from the stored pairwise .npz files for the filtered subset.
    logger.info(f"Generating visualizations for {len(matches)} filtered matches...")
    pairwise_dir = cached["pairwise_dir"]
    match_visualizations, clean_comparison_visualizations = (
        generate_visualizations_from_npz(
            pairwise_dir=pairwise_dir,
            matches=matches,
            cropped_image_path=pairwise_dir.parent / "cropped.png",
        )
    )

    # State read by load_match_details_for_rank; cached runs own no temp directory.
    LOADED_MODELS.update(
        current_match_visualizations=match_visualizations,
        current_clean_comparison_visualizations=clean_comparison_visualizations,
        current_enriched_matches=matches,
        current_filter_body_parts=filter_body_parts,
        current_temp_dir=None,
    )

    best = matches[0]
    best_name = best["leopard_name"]
    result_text = f"## {indicator_for(best['wasserstein'])} {best_name.title()}"

    rows = []
    for match in matches:
        rank = match["rank"]
        name = match["leopard_name"]
        score = match["wasserstein"]
        location = location_for(match)
        rows.append(
            [
                rank,
                indicator_for(score),
                name.title(),
                location.replace("_", " ").title(),
                f"{score:.4f}",
            ]
        )

    rank1_details = load_match_details_for_rank(rank=1)

    return (
        result_text,
        cached["segmentation_image"],
        cached["cropped_image"],
        cached["keypoints_image"],
        rows,
        *rank1_details,
    )
|
|
|
|
|
|
|
|
def _empty_identification_result(
    message: str,
    segmentation_image=None,
    cropped_image=None,
    keypoints_image=None,
) -> tuple:
    """Build the 23-element output tuple for a run with no displayable matches.

    The layout mirrors a successful run_identification return: status text,
    segmentation overlay, cropped image, keypoint overlay, match table rows,
    then 18 match-detail updates. Those are two cleared visualizations, six
    blanked header/indicator texts, five hidden empty-gallery messages and five
    emptied galleries.
    """
    return (
        message,
        segmentation_image,
        cropped_image,
        keypoints_image,
        [],
        gr.update(value=None),
        gr.update(value=None),
        *[gr.update(value="") for _ in range(6)],
        *[gr.update(visible=False) for _ in range(5)],
        *[gr.update(value=[]) for _ in range(5)],
    )


def _wasserstein_indicator(wasserstein: float) -> str:
    """Map a match score to its UI indicator: >=0.12 blue, >=0.07 green, >=0.04 yellow, else red."""
    if wasserstein >= 0.12:
        return "🔵"
    if wasserstein >= 0.07:
        return "🟢"
    if wasserstein >= 0.04:
        return "🟡"
    return "🔴"


def _discard_temp_dir(temp_dir: Path | None) -> None:
    """Best-effort removal of a per-request working directory that will not be reused."""
    if temp_dir is not None:
        shutil.rmtree(temp_dir, ignore_errors=True)


def run_identification(
    image,
    extractor: str,
    top_k: int,
    selected_locations: list[str],
    selected_body_parts: list[str],
    example_path: str | None,
    config: AppConfig,
):
    """Run the snow leopard identification pipeline on an uploaded image.

    Example images with a cached result are served from the cache, with filters
    applied there. Otherwise the full pipeline runs:
    GDINO+SAM segmentation -> best mask -> crop -> feature extraction ->
    matching against the catalog -> visualizations.

    Args:
        image: PIL Image from the Gradio upload (None if nothing was uploaded).
        extractor: Feature extractor to use ('sift', 'superpoint', 'disk', 'aliked').
        top_k: Number of top matches to return. Coerced to int because Gradio
            numeric inputs may deliver floats.
        selected_locations: Selected locations; empty or containing "all" means no filtering.
        selected_body_parts: Selected body parts; empty or containing "all" means no filtering.
        example_path: Path to the example image if one was selected (used for cache lookup).
        config: Application configuration.

    Returns:
        Tuple of 23 UI component values/updates: status markdown, segmentation
        overlay, cropped image, keypoint overlay and match table rows, followed
        by the 18 rank-1 match-detail updates from load_match_details_for_rank.
        Errors are reported through the status text rather than raised.
    """
    if image is None:
        return _empty_identification_result("Please upload an image first")

    # Both the cache path and the live pipeline need an integer count.
    top_k = int(top_k)

    filter_locations = (
        None
        if not selected_locations or "all" in selected_locations
        else selected_locations
    )
    filter_body_parts = (
        None
        if not selected_body_parts or "all" in selected_body_parts
        else selected_body_parts
    )

    # Cached example results can be filtered without re-running the pipeline.
    logger.info(f"Cache check: example_path={example_path}, extractor={extractor}")
    if example_path:
        cache_exists = is_cached(example_path, extractor)
        logger.info(f"is_cached() returned: {cache_exists}")
    else:
        cache_exists = False
        logger.info("No example_path provided, skipping cache")

    if example_path and cache_exists:
        logger.info(f"Cache hit for {example_path} with {extractor}")
        if filter_locations or filter_body_parts:
            logger.info(
                f" Applying filters: locations={filter_locations}, body_parts={filter_body_parts}"
            )
        try:
            return _load_from_cache(
                example_path,
                extractor,
                config,
                filter_locations=filter_locations,
                filter_body_parts=filter_body_parts,
                top_k=top_k,
            )
        except Exception as e:
            # A broken cache entry is not fatal: fall through to the live pipeline.
            logger.warning(f"Cache load failed, running pipeline: {e}")

    if filter_locations or filter_body_parts:
        filter_desc = []
        if filter_locations:
            filter_desc.append(f"locations: {', '.join(filter_locations)}")
        if filter_body_parts:
            filter_desc.append(f"body parts: {', '.join(filter_body_parts)}")
        logger.info(f"Applied filters - {' | '.join(filter_desc)}")
    else:
        logger.info("No filters applied - matching against entire catalog")

    temp_dir = None
    try:
        # Per-request working directory for the query image, crop and pairwise match data.
        temp_dir = Path(tempfile.mkdtemp(prefix="snowleopard_id_"))
        temp_image_path = temp_dir / "query.jpg"

        logger.info(f"Image type: {type(image)}")
        logger.info(f"Image mode: {image.mode if hasattr(image, 'mode') else 'N/A'}")
        logger.info(f"Image size: {image.size if hasattr(image, 'size') else 'N/A'}")

        # JPEG cannot store alpha/palette modes (e.g. RGBA or P from PNG uploads);
        # convert first so the save below does not raise.
        if getattr(image, "mode", None) not in (None, "RGB", "L"):
            logger.info(f"Converting image from mode {image.mode} to RGB")
            image = image.convert("RGB")
        image.save(temp_image_path, quality=95)

        saved_size = temp_image_path.stat().st_size
        logger.info(f"Saved image size: {saved_size / 1024 / 1024:.2f} MB")
        logger.info(f"Processing query image: {temp_image_path}")

        device = LOADED_MODELS.get("device", "cpu")

        # 1. Segmentation with the models preloaded by initialize_models().
        logger.info("Running GDINO+SAM segmentation...")
        seg_stage = run_segmentation_stage(
            image_path=temp_image_path,
            strategy="gdino_sam",
            confidence_threshold=0.2,
            device=device,
            gdino_processor=LOADED_MODELS.get("gdino_processor"),
            gdino_model=LOADED_MODELS.get("gdino_model"),
            sam_predictor=LOADED_MODELS.get("sam_predictor"),
            text_prompt=LOADED_MODELS.get("text_prompt", "a snow leopard."),
            box_threshold=0.30,
            text_threshold=0.20,
        )

        predictions = seg_stage["data"]["predictions"]
        logger.info(f"Number of predictions: {len(predictions)}")
        if not predictions:
            logger.warning("No predictions found from segmentation")
            logger.warning(f"Full segmentation stage: {seg_stage}")
            _discard_temp_dir(temp_dir)
            return _empty_identification_result("No snow leopards detected in image")

        # 2. Keep a single mask.
        logger.info("Selecting best mask...")
        _, selected_pred = select_best_mask(
            predictions,
            strategy="confidence_area",
        )

        # 3. Crop the query to the selected mask.
        logger.info("Preprocessing query image...")
        prep_stage = run_preprocess_stage(
            image_path=temp_image_path,
            mask=selected_pred["mask"],
            padding=5,
        )
        cropped_image_pil = prep_stage["data"]["cropped_image"]

        # Matching and the visualization helpers take file paths, so persist the crop.
        cropped_path = temp_dir / "cropped.jpg"
        cropped_image_pil.save(cropped_path)

        # 4. Feature extraction on the crop.
        logger.info(f"Extracting features using {extractor.upper()}...")
        feat_stage = run_feature_extraction_stage(
            image=cropped_image_pil,
            extractor=extractor,
            max_keypoints=2048,
            device=device,
        )
        query_features = feat_stage["data"]["features"]

        # 5. Match against the catalog; pairwise keypoint data is written to pairwise_dir.
        logger.info("Matching against catalog...")
        pairwise_dir = temp_dir / "pairwise"
        pairwise_dir.mkdir(exist_ok=True)
        match_stage = run_matching_stage(
            query_features=query_features,
            catalog_path=config.catalog_root,
            top_k=top_k,
            extractor=extractor,
            device=device,
            query_image_path=str(cropped_path),
            pairwise_output_dir=pairwise_dir,
            filter_locations=filter_locations,
            filter_body_parts=filter_body_parts,
        )

        matches = match_stage["data"]["matches"]
        if not matches:
            _discard_temp_dir(temp_dir)
            return _empty_identification_result(
                "No matches found in catalog", cropped_image=cropped_image_pil
            )

        top_match = matches[0]
        result_text = (
            f"## {_wasserstein_indicator(top_match['wasserstein'])} "
            f"{top_match['leopard_name'].title()}"
        )

        seg_viz = create_segmentation_viz(
            image_path=temp_image_path, mask=selected_pred["mask"]
        )

        # The keypoint overlay is optional; a failure must not hide the match results.
        extracted_kpts_viz = None
        try:
            query_kpts = query_features["keypoints"].cpu().numpy()
            extracted_kpts_viz = draw_keypoints_overlay(
                image_path=cropped_path,
                keypoints=query_kpts,
                max_keypoints=500,
                color="blue",
                ps=10,
            )
        except Exception as e:
            logger.error(f"Error creating extracted keypoints visualization: {e}")

        dataset_samples = []
        match_visualizations = {}
        clean_comparison_visualizations = {}

        for match in matches:
            rank = match["rank"]
            leopard_name = match["leopard_name"]
            wasserstein = match["wasserstein"]
            catalog_img_path = Path(match["filepath"])
            catalog_id = match["catalog_id"]

            # Location = the directory directly under "database" in the catalog image path.
            location = "unknown"
            catalog_metadata = get_catalog_metadata_for_id(
                config.catalog_root, catalog_id
            )
            if catalog_metadata:
                img_path_parts = Path(catalog_metadata["image_path"]).parts
                if len(img_path_parts) >= 3 and "database" in img_path_parts:
                    db_idx = img_path_parts.index("database")
                    if db_idx + 1 < len(img_path_parts):
                        location = img_path_parts[db_idx + 1]

            npz_path = pairwise_dir / f"rank_{rank:02d}_{catalog_id}.npz"
            if npz_path.exists():
                try:
                    # NpzFile keeps the archive open; close it once the arrays are read.
                    with np.load(npz_path) as pairwise_data:
                        match_visualizations[rank] = draw_matched_keypoints(
                            query_image_path=cropped_path,
                            catalog_image_path=catalog_img_path,
                            query_keypoints=pairwise_data["query_keypoints"],
                            catalog_keypoints=pairwise_data["catalog_keypoints"],
                            match_scores=pairwise_data["match_scores"],
                            max_matches=100,
                        )
                    clean_comparison_visualizations[rank] = draw_side_by_side_comparison(
                        query_image_path=cropped_path,
                        catalog_image_path=catalog_img_path,
                    )
                except Exception as e:
                    logger.error(f"Error creating visualizations for rank {rank}: {e}")

            dataset_samples.append(
                [
                    rank,
                    _wasserstein_indicator(wasserstein),
                    leopard_name.title(),
                    location.replace("_", " ").title(),
                    f"{wasserstein:.4f}",
                ]
            )

        # State read by load_match_details_for_rank (rank-1 display and row clicks).
        LOADED_MODELS["current_match_visualizations"] = match_visualizations
        LOADED_MODELS["current_clean_comparison_visualizations"] = (
            clean_comparison_visualizations
        )
        LOADED_MODELS["current_enriched_matches"] = matches
        LOADED_MODELS["current_filter_body_parts"] = filter_body_parts
        LOADED_MODELS["current_temp_dir"] = temp_dir

        rank1_details = load_match_details_for_rank(rank=1)

        return (
            result_text,
            seg_viz,
            cropped_image_pil,
            extracted_kpts_viz,
            dataset_samples,
            *rank1_details,
        )

    except Exception as e:
        logger.error(f"Error processing image: {e}", exc_info=True)
        # Nothing from a failed run is displayed, so its working files are dropped.
        _discard_temp_dir(temp_dir)
        return _empty_identification_result(f"Error processing image: {str(e)}")
|
|
|
|
|
|
|
|
def create_segmentation_viz(image_path, mask):
    """Overlay a segmentation mask in translucent red on the image at ``image_path``.

    Args:
        image_path: Path (or str) of the image file to read with OpenCV.
        mask: Array whose non-zero entries mark the segmented region. It is
            expected to be (H, W). If its leading two dims differ from the
            image, it is resized with nearest-neighbour interpolation.

    Returns:
        RGB PIL.Image in which masked pixels are blended 40% towards pure red.

    Raises:
        FileNotFoundError: If OpenCV cannot read the image.
    """
    img_bgr = cv2.imread(str(image_path))
    if img_bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(
            f"Could not read image for segmentation overlay: {image_path}"
        )
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    mask = np.asarray(mask)
    height, width = img_rgb.shape[:2]
    if mask.shape[:2] != (height, width):
        # Nearest-neighbour keeps the mask binary (no interpolated edge values).
        mask = cv2.resize(
            mask.astype(np.uint8),
            (width, height),
            interpolation=cv2.INTER_NEAREST,
        )

    overlay = img_rgb.copy()
    overlay[mask > 0] = [255, 0, 0]

    alpha = 0.4
    blended = cv2.addWeighted(img_rgb, 1 - alpha, overlay, alpha, 0)

    return Image.fromarray(blended)
|
|
|
|
|
|
|
|
def load_match_details_for_rank(rank: int) -> tuple:
    """Build the match-detail component updates for the match with the given rank.

    Shared by the automatic rank-1 display after a run and by the table-row
    selection handler. Reads the state the latest run stored in LOADED_MODELS:
    visualizations, enriched matches, body-part filter and catalog root.

    Args:
        rank: The rank to load (1-indexed).

    Returns:
        Tuple of 18 Gradio component updates:
        (matched_kpts_viz, clean_comparison_viz, header,
         head_indicator, left_flank_indicator, right_flank_indicator, tail_indicator, misc_indicator,
         head_empty_message, left_flank_empty_message, right_flank_empty_message,
         tail_empty_message, misc_empty_message,
         gallery_head, gallery_left_flank, gallery_right_flank, gallery_tail, gallery_misc)
        Everything is cleared or hidden when the rank is unknown or has no
        keypoint visualization.
    """
    body_parts = ("head", "left_flank", "right_flank", "tail", "misc")

    match_viz_by_rank = LOADED_MODELS.get("current_match_visualizations", {})
    clean_viz_by_rank = LOADED_MODELS.get("current_clean_comparison_visualizations", {})
    candidates = LOADED_MODELS.get("current_enriched_matches", [])
    active_filter = LOADED_MODELS.get("current_filter_body_parts")
    catalog_root = LOADED_MODELS.get("catalog_root")

    selected_match = next((m for m in candidates if m["rank"] == rank), None)

    if not selected_match or rank not in match_viz_by_rank:
        return (
            gr.update(value=None),
            gr.update(value=None),
            *[gr.update(value="") for _ in range(6)],
            *[gr.update(visible=False) for _ in body_parts],
            *[gr.update(value=[]) for _ in body_parts],
        )

    match_viz = match_viz_by_rank[rank]
    clean_viz = clean_viz_by_rank.get(rank)

    leopard_name = selected_match["leopard_name"]
    header_text = f"## Reference Images for {leopard_name.title()}"

    galleries = {}
    if catalog_root:
        try:
            # The location is the path component right after "database", when present.
            location = None
            path_parts = Path(selected_match["filepath"]).parts
            if "database" in path_parts:
                following = path_parts.index("database") + 1
                if following < len(path_parts):
                    location = path_parts[following]

            galleries = load_matched_individual_gallery_by_body_part(
                catalog_root=catalog_root,
                leopard_name=leopard_name,
                location=location,
            )
        except Exception as e:
            logger.error(f"Error loading gallery for {leopard_name}: {e}")
            galleries = {part: [] for part in body_parts}

    def indicator_for(part):
        # Mark body parts that the user explicitly filtered on.
        if active_filter and part in active_filter:
            return "* (filtered)"
        return ""

    empty = {part: len(galleries.get(part, [])) == 0 for part in body_parts}

    return (
        gr.update(value=match_viz),
        gr.update(value=clean_viz),
        gr.update(value=header_text),
        *[gr.update(value=indicator_for(part)) for part in body_parts],
        *[gr.update(visible=empty[part]) for part in body_parts],
        *[
            gr.update(value=galleries.get(part, []), visible=not empty[part])
            for part in body_parts
        ],
    )
|
|
|
|
|
|
|
|
def on_match_selected(evt: gr.SelectData):
    """Handle selection of a match from the results table.

    The clicked row index is mapped to the match shown in that row. Rows are
    built in the same order as LOADED_MODELS["current_enriched_matches"], and
    the row's own rank is used. Ranks therefore do not have to be exactly
    1..N; filtered cache results may carry non-contiguous ranks.

    Returns:
        The 18 match-detail updates from load_match_details_for_rank:
        visualizations, header, indicators, empty messages and body-part galleries.
    """
    index = evt.index
    # Dataframe selections report [row, col]; other components report a plain index.
    selected_row = index[0] if isinstance(index, (list, tuple)) else index

    matches = LOADED_MODELS.get("current_enriched_matches", [])
    # NOTE(review): assumes the table is not re-sorted client-side.
    if isinstance(selected_row, int) and 0 <= selected_row < len(matches):
        selected_rank = matches[selected_row]["rank"]
    else:
        # No stored match for this row: keep the old positional mapping, which
        # yields the cleared detail view when the rank is unknown.
        selected_rank = selected_row + 1

    return load_match_details_for_rank(selected_rank)
|
|
|
|
|
|
|
|
def load_matched_individual_gallery_by_body_part(
    catalog_root: Path,
    leopard_name: str,
    location: str | None = None,
) -> dict[str, list[tuple]]:
    """Load all reference images of a matched individual, grouped by body part.

    Args:
        catalog_root: Path to the catalog root directory.
        leopard_name: Name of the matched individual (e.g., "karindas").
        location: Location directory under ``database`` (e.g., "skycrest_valley").
            When omitted, every location directory is searched for the
            individual in sorted order, and the first hit is used.

    Returns:
        Dict mapping body part to a list of ``(PIL.Image, caption)`` tuples:
        {
            "head": [(img1, caption1), (img2, caption2), ...],
            "left_flank": [...],
            "right_flank": [...],
            "tail": [...],
            "misc": [...]
        }
        Unrecognized body parts are grouped under "misc". Missing metadata or
        unreadable images are logged and yield empty or partial lists; nothing
        is raised.
    """
    galleries = {
        part: [] for part in ("head", "left_flank", "right_flank", "tail", "misc")
    }

    database_dir = catalog_root / "database"
    if location:
        metadata_path = database_dir / location / leopard_name / "metadata.yaml"
    else:
        # Sorted so the result is deterministic if the name exists in several locations.
        metadata_path = None
        if database_dir.exists():
            for loc_dir in sorted(database_dir.iterdir()):
                candidate = loc_dir / leopard_name / "metadata.yaml"
                if loc_dir.is_dir() and candidate.exists():
                    metadata_path = candidate
                    break

    if not metadata_path or not metadata_path.exists():
        logger.warning(f"Metadata not found for {leopard_name}")
        return galleries

    try:
        metadata = load_leopard_metadata(metadata_path)

        for img_entry in metadata["reference_images"]:
            body_part = img_entry.get("body_part", "misc")
            if body_part not in galleries:
                body_part = "misc"

            # Image paths in the metadata are relative to the database directory.
            img_path = database_dir / img_entry["path"]
            try:
                img = Image.open(img_path)
                # Decode now so corrupt files are caught here. For single-frame
                # images this also lets Pillow close the file instead of holding
                # one handle per gallery image.
                img.load()
                galleries[body_part].append((img, body_part))
            except Exception as e:
                logger.error(f"Error loading image {img_path}: {e}")

    except Exception as e:
        logger.error(f"Error loading metadata for {leopard_name}: {e}")

    return galleries
|
|
|
|
|
|
|
|
def cleanup_temp_files():
    """Delete the temporary working directory recorded by the latest pipeline run.

    Does nothing when no directory is recorded or it no longer exists. Deletion
    errors are logged as warnings rather than raised.
    """
    previous_dir = LOADED_MODELS.get("current_temp_dir")
    if not previous_dir or not previous_dir.exists():
        return

    try:
        shutil.rmtree(previous_dir)
    except Exception as e:
        logger.warning(f"Error cleaning up temp directory: {e}")
    else:
        logger.info(f"Cleaned up temporary directory: {previous_dir}")
|
|
|
|
|
|
|
|
def create_leopard_tab(leopard_metadata, config: AppConfig):
    """Create a catalog-browser tab showing one leopard's reference images.

    Renders a Markdown summary (name, location, image count, body parts)
    followed by a gallery of every reference image listed in the metadata.
    Images that fail to load are logged and skipped.

    Must be called inside an active ``gr.Tabs()`` context, because it opens
    a ``gr.Tab``.

    Args:
        leopard_metadata: Metadata dictionary for the leopard individual.
            Reads ``leopard_name`` (falling back to ``individual_name``),
            ``location``, ``statistics`` and ``reference_images``.
        config: Application configuration. ``config.catalog_root`` is used to
            resolve each reference image's relative ``path``.
    """
    leopard_name = (
        leopard_metadata.get("leopard_name")
        or leopard_metadata.get("individual_name")
        # Without a name, ``.title()`` below would fail on None.
        or "unknown"
    )
    location = leopard_metadata.get("location", "unknown")
    reference_images = leopard_metadata.get("reference_images", [])
    statistics = leopard_metadata.get("statistics", {})
    total_images = statistics.get("total_reference_images", len(reference_images))

    # Accept either key name for the list of body parts.
    body_parts = statistics.get(
        "body_parts_represented", statistics.get("body_parts", [])
    )
    body_parts_str = ", ".join(body_parts) if body_parts else "N/A"

    database_dir = config.catalog_root / "database"

    with gr.Tab(f"{leopard_name}"):
        gr.Markdown(
            f"### {leopard_name.title()}\n"
            f"**Location:** {location.replace('_', ' ').title()} | "
            f"**{total_images} images** | "
            f"**Body parts:** {body_parts_str}"
        )

        gallery_data = []
        for img_entry in reference_images:
            img_path = database_dir / img_entry["path"]
            caption = img_entry.get("body_part", "unknown")
            try:
                # Image.open is lazy and keeps the file open until the pixels
                # are read. Copying inside ``with`` reads them now and
                # releases the handle immediately. This matters because this
                # function runs for every leopard in the catalog.
                with Image.open(img_path) as src:
                    img = src.copy()
            except Exception as e:
                logger.error(f"Error loading image {img_path}: {e}")
                continue
            gallery_data.append((img, caption))

        gr.Gallery(
            value=gallery_data,
            label=f"Reference Images for {leopard_name.title()}",
            columns=6,
            height=700,
            object_fit="scale-down",
            allow_preview=True,
        )
|
|
|
|
|
|
|
|
_EMPTY_BODY_PART_HTML = (
    '<div style="text-align: center; padding: 60px 20px; color: #888;">'
    '<p style="font-size: 16px;">No reference images available for this body part</p>'
    "</div>"
)

# Labels of the per-body-part gallery tabs in the match detail view.
# The order fixes the order of the matching output components wired to
# run_identification and on_match_selected in create_app.
_BODY_PART_TAB_LABELS = ("Head", "Left Flank", "Right Flank", "Tail", "Other")


def _find_example_images(examples_dir, suffixes=(".jpg", ".jpeg", ".png")):
    """Return example image paths in display order ("Ayima" images last).

    Suffixes are compared case-insensitively and every file appears once.
    Globbing ``*.jpg`` and ``*.JPG`` separately would list each JPEG twice
    on Windows, where pathlib glob matching is case-insensitive. A missing
    ``examples_dir`` yields an empty list.
    """
    images = [
        path
        for path in examples_dir.glob("*")
        if path.is_file() and path.suffix.lower() in suffixes
    ]
    # Non-"Ayima" examples first, then alphabetical by file name.
    images.sort(key=lambda path: ("Ayima" in path.name, path.name))
    return images


def _create_body_part_tab(label):
    """Create one body-part tab of the match detail view.

    Must be called inside an active ``gr.Tabs()`` context.

    Returns:
        ``(indicator, empty_message, gallery)``: a Markdown status slot, a
        hidden "no images" placeholder, and the reference-image gallery.
    """
    with gr.Tab(label):
        indicator = gr.Markdown("")
        empty_message = gr.Markdown(value=_EMPTY_BODY_PART_HTML, visible=False)
        gallery = gr.Gallery(
            columns=6,
            height=400,
            object_fit="scale-down",
            allow_preview=True,
        )
    return indicator, empty_message, gallery


def create_app(config: AppConfig):
    """Create and configure the Gradio application.

    Initializes models and loads catalog data, then builds two top-level
    tabs. "Identify Snow Leopard" holds the input, filters, results and
    match visualizations. "Explore Catalog" holds per-location, per-leopard
    reference galleries.

    Args:
        config: Application configuration

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    initialize_models(config)
    catalog_index, individuals_data = load_catalog_data(config)
    example_images = _find_example_images(config.examples_dir)

    with gr.Blocks(title="Snow Leopard Identification") as app:
        # Path of the example the current input came from (None for uploads).
        # It is forwarded to run_identification as ``example_path``.
        selected_example_state = gr.State(value=None)

        with gr.Tabs():
            with gr.Tab("Identify Snow Leopard"):
                gr.Markdown("""
                Upload a snow leopard image or select an example to identify which individual it is.
                The system will detect the leopard, extract distinctive features, and match against the catalog.
                """)

                with gr.Row():
                    with gr.Column(scale=1):
                        image_input = gr.Image(
                            type="pil",
                            label="Upload Snow Leopard Image",
                            sources=["upload", "clipboard"],
                        )
                        examples_component = gr.Examples(
                            examples=[[str(img)] for img in example_images],
                            inputs=image_input,
                            label="Example Images",
                        )

                        def on_example_select(evt: gr.SelectData):
                            """Update state when an example is selected."""
                            if evt.index is not None:
                                return str(example_images[evt.index])
                            return None

                        examples_component.dataset.select(
                            fn=on_example_select,
                            outputs=[selected_example_state],
                        )
                        # A fresh upload is not an example, so forget any
                        # previously selected example path.
                        image_input.upload(
                            fn=lambda: None,
                            outputs=[selected_example_state],
                        )

                        available_locations = get_available_locations(
                            config.catalog_root
                        )
                        location_filter = gr.Dropdown(
                            choices=available_locations,
                            value=["all"],
                            multiselect=True,
                            label="Filter by Location",
                            info="Select locations to search (default: all locations)",
                        )

                        available_body_parts = get_available_body_parts(
                            config.catalog_root
                        )
                        body_part_filter = gr.Dropdown(
                            choices=available_body_parts,
                            value=["all"],
                            multiselect=True,
                            label="Filter by Body Part",
                            info="Select body parts to match (default: all body parts)",
                        )

                        with gr.Accordion("Advanced Configuration", open=False):
                            available_extractors = get_available_extractors(
                                config.catalog_root
                            )
                            # Prefer SIFT; otherwise fall back to the first
                            # available extractor (or "sift" if none exist).
                            if "sift" in available_extractors or not available_extractors:
                                default_extractor = "sift"
                            else:
                                default_extractor = available_extractors[0]
                            extractor_dropdown = gr.Dropdown(
                                choices=available_extractors,
                                value=default_extractor,
                                label="Feature Extractor",
                                info=f"Available: {', '.join(available_extractors)}",
                                scale=1,
                            )

                            top_k_input = gr.Number(
                                value=config.top_k,
                                label="Top-K Matches",
                                info="Number of top matches to return",
                                minimum=1,
                                maximum=20,
                                step=1,
                                precision=0,
                                scale=1,
                            )

                        submit_btn = gr.Button(
                            value="Identify Snow Leopard",
                            variant="primary",
                            size="lg",
                        )

                    with gr.Column(scale=4):
                        result_text = gr.Markdown("")

                        with gr.Tabs():
                            with gr.Tab("Model Internals"):
                                gr.Markdown("""
                                View the internal processing steps: segmentation mask, cropped leopard, and extracted keypoints.
                                """)
                                with gr.Row():
                                    seg_viz = gr.Image(
                                        label="Segmentation Overlay",
                                        type="pil",
                                    )
                                    cropped_image = gr.Image(
                                        label="Extracted Snow Leopard",
                                        type="pil",
                                    )
                                    extracted_kpts_viz = gr.Image(
                                        label="Extracted Keypoints",
                                        type="pil",
                                    )

                            with gr.Tab("Top Matches"):
                                gr.Markdown("""
                                Click a row to view detailed feature matching visualization and all reference images for that leopard.

                                **Higher Wasserstein distance = better match** (typical range: 0.04-0.27)

                                **Confidence Levels:** 🔵 Excellent (>=0.12) | 🟢 Good (>=0.07) | 🟡 Fair (>=0.04) | 🔴 Uncertain (<0.04)
                                """)

                                matches_dataset = gr.Dataframe(
                                    headers=[
                                        "Rank",
                                        "Confidence",
                                        "Leopard Name",
                                        "Location",
                                        "Wasserstein",
                                    ],
                                    label="Top Matches",
                                    wrap=True,
                                    col_count=(5, "fixed"),
                                )

                                with gr.Column():
                                    with gr.Tabs():
                                        with gr.Tab("Matched Keypoints"):
                                            gr.Markdown(
                                                "Feature matching with keypoints and confidence-coded connecting lines. "
                                                "**Green** = high confidence, **Yellow** = medium, **Red** = low."
                                            )
                                            matched_kpts_viz = gr.Image(
                                                type="pil",
                                                show_label=False,
                                            )

                                        with gr.Tab("Clean Comparison"):
                                            gr.Markdown(
                                                "Side-by-side comparison without feature annotations. "
                                                "Useful for assessing overall visual similarity and spotting patterns."
                                            )
                                            clean_comparison_viz = gr.Image(
                                                type="pil",
                                                show_label=False,
                                            )

                                    selected_match_header = gr.Markdown(
                                        "", visible=True
                                    )

                                    with gr.Tabs():
                                        body_part_tabs = [
                                            _create_body_part_tab(label)
                                            for label in _BODY_PART_TAB_LABELS
                                        ]

                # Regroup per-tab triples into per-kind lists:
                # all indicators, then all empty messages, then all galleries.
                indicators, empty_messages, galleries = (
                    list(group) for group in zip(*body_part_tabs)
                )

                # Components refreshed both after identification and when a
                # match row is selected. Gradio maps handler return values to
                # outputs positionally, so this order must match what
                # run_identification / on_match_selected return.
                match_detail_outputs = [
                    matched_kpts_viz,
                    clean_comparison_viz,
                    selected_match_header,
                    *indicators,
                    *empty_messages,
                    *galleries,
                ]

                # Kept as a lambda so the event's default API name is unchanged.
                submit_btn.click(
                    fn=lambda img, ext, top_k, locs, parts, ex_path: run_identification(
                        image=img,
                        extractor=ext,
                        # A cleared Number box arrives as None; use the default.
                        top_k=int(top_k) if top_k is not None else config.top_k,
                        selected_locations=locs,
                        selected_body_parts=parts,
                        example_path=ex_path,
                        config=config,
                    ),
                    inputs=[
                        image_input,
                        extractor_dropdown,
                        top_k_input,
                        location_filter,
                        body_part_filter,
                        selected_example_state,
                    ],
                    outputs=[
                        result_text,
                        seg_viz,
                        cropped_image,
                        extracted_kpts_viz,
                        matches_dataset,
                        *match_detail_outputs,
                    ],
                )

                matches_dataset.select(
                    fn=on_match_selected,
                    outputs=match_detail_outputs,
                )

            with gr.Tab("Explore Catalog"):
                gr.Markdown(
                    """
                    ## Snow Leopard Catalog Browser

                    Browse the reference catalog of known snow leopard individuals.
                    Each individual has multiple reference images from different body parts and locations.
                    """
                )

                stats = catalog_index.get("statistics", {})
                formatted_locations = [
                    loc.replace("_", " ").title() for loc in stats.get("locations", [])
                ]
                gr.Markdown(
                    f"""
                    ### Catalog Statistics
                    - **Total Individuals:** {stats.get("total_individuals", "N/A")}
                    - **Total Images:** {stats.get("total_reference_images", "N/A")}
                    - **Locations:** {", ".join(formatted_locations)}
                    - **Body Parts:** {", ".join(stats.get("body_parts", []))}
                    """
                )

                gr.Markdown("---")
                gr.Markdown("### Individual Leopards by Location")

                individuals_by_location = {}
                for individual_data in individuals_data:
                    location = individual_data.get("location", "unknown")
                    individuals_by_location.setdefault(location, []).append(
                        individual_data
                    )

                with gr.Tabs():
                    for location in sorted(individuals_by_location):
                        with gr.Tab(location.replace("_", " ").title()):
                            with gr.Tabs():
                                for leopard_data in individuals_by_location[location]:
                                    create_leopard_tab(
                                        leopard_metadata=leopard_data, config=config
                                    )

        # NOTE(review): cleanup_temp_files removes the directory stored in the
        # module-global LOADED_MODELS. If several sessions share the process,
        # one session closing may delete files another session still uses.
        # Confirm this is acceptable.
        app.unload(cleanup_temp_files)

        def load_first_example():
            """Load the first example image when the app starts.

            Returns both the image AND the path so the cache can be used
            when the user clicks Identify without selecting a new example.
            """
            if not example_images:
                return None, None
            try:
                first_image = Image.open(example_images[0])
            except Exception as e:
                logger.error(f"Error loading first example image: {e}")
                return None, None
            return first_image, str(example_images[0])

        app.load(fn=load_first_example, outputs=[image_input, selected_example_state])

    return app
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Prepare on-disk data (snowleopard_reid.data_setup) before the catalog
    # check below.
    ensure_data_extracted()

    logger.info("Checking for SAM HQ model...")
    sam_path = ensure_sam_model()

    if not CATALOG_ROOT.exists():
        logger.error(f"Catalog not found: {CATALOG_ROOT}")
        logger.error("Please ensure catalog data is present in data/catalog/")
        # Use sys.exit, not the builtin ``exit``. The site module injects
        # ``exit`` for interactive use, and it is absent under ``python -S``.
        sys.exit(1)

    if not EXAMPLES_DIR.exists():
        logger.warning(f"Examples directory not found: {EXAMPLES_DIR}")
        EXAMPLES_DIR.mkdir(parents=True, exist_ok=True)

    config = AppConfig(
        model_path=None,
        catalog_root=CATALOG_ROOT,
        examples_dir=EXAMPLES_DIR,
        top_k=TOP_K_DEFAULT,
        port=7860,
        share=False,
        sam_checkpoint_path=sam_path,
        sam_model_type=SAM_MODEL_TYPE,
        gdino_model_id=GDINO_MODEL_ID,
        text_prompt=TEXT_PROMPT,
    )

    logger.info("Building Gradio interface...")
    app = create_app(config)

    logger.info("Launching app...")
    # Take port/share from the config so they are defined in one place.
    app.launch(
        server_name="0.0.0.0",
        server_port=config.port,
        share=config.share,
    )
|
|
|