"""Gradio web application for snow leopard identification and catalog exploration.
This module is the Gradio frontend for the snow leopard identification system.
Users can upload images, view matches against the reference catalog, and browse
known individuals directly in the browser.
Features:
- Upload snow leopard images or select from examples
- Run full identification pipeline with GDINO+SAM segmentation
- View top-K matches with Wasserstein distance scores
- Explore complete leopard catalog with thumbnails
- Visualize matched keypoints between query and catalog images
Usage:
# Local testing with uv:
uv sync
uv run python app.py
# Deployed on Hugging Face Spaces
"""
import sys
from pathlib import Path
# Add src to path for imports BEFORE importing snowleopard_reid
SPACE_ROOT = Path(__file__).parent
sys.path.insert(0, str(SPACE_ROOT / "src"))
import logging
import shutil
import tempfile
from dataclasses import dataclass
import cv2
import gradio as gr
import numpy as np
import torch
import yaml
from huggingface_hub import hf_hub_download
from PIL import Image
from snowleopard_reid.cache import (
filter_cached_matches,
generate_visualizations_from_npz,
is_cached,
load_cached_results,
)
from snowleopard_reid.catalog import (
get_available_body_parts,
get_available_locations,
get_catalog_metadata_for_id,
load_catalog_index,
load_leopard_metadata,
)
from snowleopard_reid.data_setup import ensure_data_extracted
from snowleopard_reid.pipeline.stages import (
run_feature_extraction_stage,
run_matching_stage,
run_preprocess_stage,
run_segmentation_stage,
select_best_mask,
)
from snowleopard_reid.pipeline.stages.segmentation import (
load_gdino_model,
load_sam_predictor,
)
from snowleopard_reid.visualization import (
draw_keypoints_overlay,
draw_matched_keypoints,
draw_side_by_side_comparison,
)
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Configuration (hardcoded for HF Spaces / local dev)
CATALOG_ROOT = SPACE_ROOT / "data" / "catalog"
SAM_CHECKPOINT_DIR = SPACE_ROOT / "data" / "models"
SAM_CHECKPOINT_NAME = "sam_hq_vit_l.pth"
EXAMPLES_DIR = SPACE_ROOT / "data" / "examples"
GDINO_MODEL_ID = "IDEA-Research/grounding-dino-base"
TEXT_PROMPT = "a snow leopard."
TOP_K_DEFAULT = 5
SAM_MODEL_TYPE = "vit_l"
@dataclass
class AppConfig:
"""Configuration for the Snow Leopard ID UI application."""
model_path: Path | None
catalog_root: Path
examples_dir: Path
top_k: int
port: int
share: bool
# GDINO+SAM parameters
sam_checkpoint_path: Path
sam_model_type: str
gdino_model_id: str
text_prompt: str
def ensure_sam_model() -> Path:
"""Download SAM HQ model if not present.
Returns:
Path to the SAM HQ checkpoint file
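Example (path derived from the module-level constants):
    sam_path = ensure_sam_model()
    # -> SPACE_ROOT / "data" / "models" / "sam_hq_vit_l.pth"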
"""
sam_path = SAM_CHECKPOINT_DIR / SAM_CHECKPOINT_NAME
if not sam_path.exists():
logger.info("Downloading SAM HQ model (1.6GB)...")
SAM_CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
hf_hub_download(
repo_id="lkeab/hq-sam",
filename=SAM_CHECKPOINT_NAME,
local_dir=SAM_CHECKPOINT_DIR,
)
logger.info("SAM HQ model downloaded successfully")
return sam_path
def get_available_extractors(catalog_root: Path) -> list[str]:
"""Get list of available feature extractors from catalog.
Args:
catalog_root: Root directory of the leopard catalog
Returns:
List of available extractor names (e.g., ['sift', 'superpoint'])
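Illustrative catalog_index.yaml fragment (layout inferred from the keys
read below; actual contents depend on how the catalog was built):
    feature_extractors:
        sift: {...}
        superpoint: {...}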
"""
try:
catalog_index = load_catalog_index(catalog_root)
extractors = list(catalog_index.get("feature_extractors", {}).keys())
if not extractors:
logger.warning(f"No extractors found in catalog at {catalog_root}")
return ["sift"] # Default fallback
return extractors
except Exception as e:
logger.error(f"Failed to load catalog index: {e}")
return ["sift"] # Default fallback
# Global state for models and catalog (loaded at startup)
LOADED_MODELS = {}
def load_catalog_data(config: AppConfig):
"""Load catalog index and individual leopard metadata.
Args:
config: Application configuration containing catalog_root
Returns:
Tuple of (catalog_index, individuals_data)
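Illustrative catalog_index.yaml fragment (inferred from the keys used
below; the path shown is hypothetical):
    individuals:
        - metadata_path: database/skycrest_valley/karindas/metadata.yaml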
"""
catalog_index_path = config.catalog_root / "catalog_index.yaml"
# Load catalog index
with open(catalog_index_path) as f:
catalog_index = yaml.safe_load(f)
# Load metadata for each individual
individuals_data = []
for individual in catalog_index["individuals"]:
metadata_path = config.catalog_root / individual["metadata_path"]
with open(metadata_path) as f:
leopard_metadata = yaml.safe_load(f)
individuals_data.append(leopard_metadata)
return catalog_index, individuals_data
def initialize_models(config: AppConfig):
"""Load models at startup for faster inference.
Args:
config: Application configuration containing model paths
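Side effects:
    Populates the module-level LOADED_MODELS dict with the keys
    "gdino_processor", "gdino_model", "sam_predictor", "device",
    "catalog_root", and "text_prompt".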
"""
logger.info("Initializing models...")
# Check for GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")
if device == "cuda":
gpu_name = torch.cuda.get_device_name(0)
gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
logger.info(f"GPU: {gpu_name} ({gpu_memory:.1f} GB)")
# Load Grounding DINO model
logger.info(f"Loading Grounding DINO model: {config.gdino_model_id}")
gdino_processor, gdino_model = load_gdino_model(
model_id=config.gdino_model_id,
device=device,
)
LOADED_MODELS["gdino_processor"] = gdino_processor
LOADED_MODELS["gdino_model"] = gdino_model
logger.info("Grounding DINO model loaded successfully")
# Load SAM HQ model
logger.info(
f"Loading SAM HQ model from {config.sam_checkpoint_path} (type: {config.sam_model_type})"
)
sam_predictor = load_sam_predictor(
checkpoint_path=config.sam_checkpoint_path,
model_type=config.sam_model_type,
device=device,
)
LOADED_MODELS["sam_predictor"] = sam_predictor
logger.info("SAM HQ model loaded successfully")
# Store device info and catalog root for callbacks
LOADED_MODELS["device"] = device
LOADED_MODELS["catalog_root"] = config.catalog_root
LOADED_MODELS["text_prompt"] = config.text_prompt
logger.info("Models initialized successfully")
def _load_from_cache(
example_path: str,
extractor: str,
config: "AppConfig",
filter_locations: list[str] | None = None,
filter_body_parts: list[str] | None = None,
top_k: int = 5,
):
"""Load cached pipeline results with optional filtering and return UI component updates.
Supports the v2.0 cache format which stores ALL matches with location/body_part
metadata, enabling client-side filtering without re-running the pipeline.
Args:
example_path: Path to the example image
extractor: Feature extractor name
config: Application configuration
filter_locations: Optional list of locations to filter by
filter_body_parts: Optional list of body parts to filter by
top_k: Number of top matches to return after filtering
Returns:
Tuple of 23 UI components matching run_identification output
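Illustrative v2.0 cache shape (field names taken from how predictions
is consumed below; values are hypothetical):
    predictions = {
        "all_matches": [
            {"rank": 1, "leopard_name": "karindas", "wasserstein": 0.13,
             "catalog_id": "...", "location": "skycrest_valley"},
            ...
        ],
    }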
"""
# Load cached results
cached = load_cached_results(example_path, extractor)
predictions = cached["predictions"]
# Support both v1.0 ("matches") and v2.0 ("all_matches") cache formats
if "all_matches" in predictions:
all_matches = predictions["all_matches"]
else:
# Fallback for v1.0 cache format (no filtering support)
all_matches = predictions.get("matches", [])
# Filter and re-rank matches
matches = filter_cached_matches(
all_matches=all_matches,
filter_locations=filter_locations,
filter_body_parts=filter_body_parts,
top_k=top_k,
)
if not matches:
# No matches after filtering - return empty results
return (
"No matches found with the selected filters",
cached["segmentation_image"],
cached["cropped_image"],
cached["keypoints_image"],
[],
gr.update(value=None),
gr.update(value=None),
gr.update(value=""),
gr.update(value=""),
gr.update(value=""),
gr.update(value=""),
gr.update(value=""),
gr.update(value=""),
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=False),
gr.update(value=[]),
gr.update(value=[]),
gr.update(value=[]),
gr.update(value=[]),
gr.update(value=[]),
)
# Generate visualizations on-demand from NPZ data
logger.info(f"Generating visualizations for {len(matches)} filtered matches...")
match_visualizations, clean_comparison_visualizations = (
generate_visualizations_from_npz(
pairwise_dir=cached["pairwise_dir"],
matches=matches,
cropped_image_path=cached["pairwise_dir"].parent / "cropped.png",
)
)
# Store in global state for match selection
LOADED_MODELS["current_match_visualizations"] = match_visualizations
LOADED_MODELS["current_clean_comparison_visualizations"] = (
clean_comparison_visualizations
)
LOADED_MODELS["current_enriched_matches"] = matches
LOADED_MODELS["current_filter_body_parts"] = filter_body_parts
LOADED_MODELS["current_temp_dir"] = None # No temp dir for cached results
# Top match info for result text
top_match = matches[0]
top_leopard_name = top_match["leopard_name"]
top_wasserstein = top_match["wasserstein"]
# Determine confidence level (higher Wasserstein = better match)
if top_wasserstein >= 0.12:
confidence_indicator = "🔵" # Excellent
elif top_wasserstein >= 0.07:
confidence_indicator = "🟢" # Good
elif top_wasserstein >= 0.04:
confidence_indicator = "🟡" # Fair
else:
confidence_indicator = "🔴" # Uncertain
result_text = f"## {confidence_indicator} {top_leopard_name.title()}"
# Build dataset for top-K matches table
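# Each row: [rank, indicator, name, location, score], e.g.
# [1, "🔵", "Karindas", "Skycrest Valley", "0.1342"] (values illustrative)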
dataset_samples = []
for match in matches:
rank = match["rank"]
leopard_name = match["leopard_name"]
wasserstein = match["wasserstein"]
# Use location from cache (v2.0) or extract from path
location = match.get("location", "unknown")
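# Catalog path convention (see load_matched_individual_gallery_by_body_part):
# database/{location}/{individual}/...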
if location == "unknown":
catalog_id = match["catalog_id"]
catalog_metadata = get_catalog_metadata_for_id(
config.catalog_root, catalog_id
)
if catalog_metadata:
img_path_parts = Path(catalog_metadata["image_path"]).parts
try:
db_idx = img_path_parts.index("database")
if db_idx + 1 < len(img_path_parts):
location = img_path_parts[db_idx + 1]
except ValueError:
pass
# Confidence indicator
if wasserstein >= 0.12:
indicator = "🔵"
elif wasserstein >= 0.07:
indicator = "🟢"
elif wasserstein >= 0.04:
indicator = "🟡"
else:
indicator = "🔴"
dataset_samples.append(
[
rank,
indicator,
leopard_name.title(),
location.replace("_", " ").title(),
f"{wasserstein:.4f}",
]
)
# Load rank 1 details
rank1_details = load_match_details_for_rank(rank=1)
# Return all 23 outputs
return (
result_text, # 1. Top match result text
cached["segmentation_image"], # 2. Segmentation overlay
cached["cropped_image"], # 3. Cropped leopard
cached["keypoints_image"], # 4. Extracted keypoints
dataset_samples, # 5. Matches table data
*rank1_details, # 6-23. visualizations, header, indicators, galleries
)
def run_identification(
image,
extractor: str,
top_k: int,
selected_locations: list[str],
selected_body_parts: list[str],
example_path: str | None,
config: AppConfig,
):
"""Run snow leopard identification pipeline on uploaded image.
Args:
image: PIL Image from Gradio upload
extractor: Feature extractor to use ('sift', 'superpoint', 'disk', 'aliked')
top_k: Number of top matches to return
selected_locations: List of selected locations (includes "all" for no filtering)
selected_body_parts: List of selected body parts (includes "all" for no filtering)
example_path: Path to example image if selected from examples (for cache lookup)
config: Application configuration
Returns:
Tuple of UI components to update
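Filter semantics (conversion happens at the top of the function body):
    ["all"] or empty selection -> None (match the whole catalog)
    ["skycrest_valley"]        -> restrict to that location (illustrative value)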
"""
if image is None:
# Return 23 empty outputs (5 pipeline + 18 rank 1 details)
return (
"Please upload an image first", # 1. result_text
None, # 2. seg_viz
None, # 3. cropped_image
None, # 4. extracted_kpts_viz
[], # 5. dataset_samples
gr.update(value=None), # 6. matched_kpts_viz
gr.update(value=None), # 7. clean_comparison_viz
gr.update(value=""), # 8. header
gr.update(value=""), # 9. head indicator
gr.update(value=""), # 10. left_flank indicator
gr.update(value=""), # 11. right_flank indicator
gr.update(value=""), # 12. tail indicator
gr.update(value=""), # 13. misc indicator
gr.update(visible=False), # 14. head empty message
gr.update(visible=False), # 15. left_flank empty message
gr.update(visible=False), # 16. right_flank empty message
gr.update(visible=False), # 17. tail empty message
gr.update(visible=False), # 18. misc empty message
gr.update(value=[]), # 19. head gallery
gr.update(value=[]), # 20. left_flank gallery
gr.update(value=[]), # 21. right_flank gallery
gr.update(value=[]), # 22. tail gallery
gr.update(value=[]), # 23. misc gallery
)
# Convert filter selections to None if "all" is selected
filter_locations = (
None
if not selected_locations or "all" in selected_locations
else selected_locations
)
filter_body_parts_parsed = (
None
if not selected_body_parts or "all" in selected_body_parts
else selected_body_parts
)
# Debug logging for cache check
logger.info(f"Cache check: example_path={example_path}, extractor={extractor}")
if example_path:
cache_exists = is_cached(example_path, extractor)
logger.info(f"is_cached() returned: {cache_exists}")
else:
cache_exists = False
logger.info("No example_path provided, skipping cache")
# Check cache for example images (v2.0 cache supports filtering)
if example_path and cache_exists:
logger.info(f"Cache hit for {example_path} with {extractor}")
if filter_locations or filter_body_parts_parsed:
logger.info(
f" Applying filters: locations={filter_locations}, body_parts={filter_body_parts_parsed}"
)
try:
return _load_from_cache(
example_path,
extractor,
config,
filter_locations=filter_locations,
filter_body_parts=filter_body_parts_parsed,
top_k=int(top_k),
)
except Exception as e:
logger.warning(f"Cache load failed, running pipeline: {e}")
# Fall through to run full pipeline
# Use the already-parsed filter values for the pipeline
filter_body_parts = filter_body_parts_parsed
# Log applied filters
if filter_locations or filter_body_parts:
filter_desc = []
if filter_locations:
filter_desc.append(f"locations: {', '.join(filter_locations)}")
if filter_body_parts:
filter_desc.append(f"body parts: {', '.join(filter_body_parts)}")
logger.info(f"Applied filters - {' | '.join(filter_desc)}")
else:
logger.info("No filters applied - matching against entire catalog")
try:
# Create temporary directory for this query
temp_dir = Path(tempfile.mkdtemp(prefix="snowleopard_id_"))
temp_image_path = temp_dir / "query.jpg"
# Save uploaded image
logger.info(f"Image type: {type(image)}")
logger.info(f"Image mode: {image.mode if hasattr(image, 'mode') else 'N/A'}")
logger.info(f"Image size: {image.size if hasattr(image, 'size') else 'N/A'}")
# Convert to RGB so RGBA/paletted uploads can still be saved as JPEG
image.convert("RGB").save(temp_image_path, quality=95)
# Verify saved image
saved_size = temp_image_path.stat().st_size
logger.info(f"Saved image size: {saved_size / 1024 / 1024:.2f} MB")
logger.info(f"Processing query image: {temp_image_path}")
device = LOADED_MODELS.get("device", "cpu")
# Step 1: Run GDINO+SAM segmentation using pre-loaded models
logger.info("Running GDINO+SAM segmentation...")
gdino_processor = LOADED_MODELS.get("gdino_processor")
gdino_model = LOADED_MODELS.get("gdino_model")
sam_predictor = LOADED_MODELS.get("sam_predictor")
text_prompt = LOADED_MODELS.get("text_prompt", "a snow leopard.")
seg_stage = run_segmentation_stage(
image_path=temp_image_path,
strategy="gdino_sam",
confidence_threshold=0.2,
device=device,
gdino_processor=gdino_processor,
gdino_model=gdino_model,
sam_predictor=sam_predictor,
text_prompt=text_prompt,
box_threshold=0.30,
text_threshold=0.20,
)
predictions = seg_stage["data"]["predictions"]
logger.info(f"Number of predictions: {len(predictions)}")
if not predictions:
logger.warning("No predictions found from segmentation")
logger.warning(f"Full segmentation stage: {seg_stage}")
# Return 23 empty outputs (5 pipeline + 18 rank 1 details)
return (
"No snow leopards detected in image", # 1. result_text
None, # 2. seg_viz
None, # 3. cropped_image
None, # 4. extracted_kpts_viz
[], # 5. dataset_samples
gr.update(value=None), # 6. matched_kpts_viz
gr.update(value=None), # 7. clean_comparison_viz
gr.update(value=""), # 8. header
gr.update(value=""), # 9. head indicator
gr.update(value=""), # 10. left_flank indicator
gr.update(value=""), # 11. right_flank indicator
gr.update(value=""), # 12. tail indicator
gr.update(value=""), # 13. misc indicator
gr.update(visible=False), # 14. head empty message
gr.update(visible=False), # 15. left_flank empty message
gr.update(visible=False), # 16. right_flank empty message
gr.update(visible=False), # 17. tail empty message
gr.update(visible=False), # 18. misc empty message
gr.update(value=[]), # 19. head gallery
gr.update(value=[]), # 20. left_flank gallery
gr.update(value=[]), # 21. right_flank gallery
gr.update(value=[]), # 22. tail gallery
gr.update(value=[]), # 23. misc gallery
)
# Step 2: Select best mask
logger.info("Selecting best mask...")
selected_idx, selected_pred = select_best_mask(
predictions,
strategy="confidence_area",
)
# Step 3: Preprocess (crop and mask)
logger.info("Preprocessing query image...")
prep_stage = run_preprocess_stage(
image_path=temp_image_path,
mask=selected_pred["mask"],
padding=5,
)
cropped_image_pil = prep_stage["data"]["cropped_image"]
# Save cropped image for visualization later
cropped_path = temp_dir / "cropped.jpg"
cropped_image_pil.save(cropped_path)
# Step 4: Extract features
logger.info(f"Extracting features using {extractor.upper()}...")
feat_stage = run_feature_extraction_stage(
image=cropped_image_pil,
extractor=extractor,
max_keypoints=2048,
device=device,
)
query_features = feat_stage["data"]["features"]
# Step 5: Match against catalog
logger.info("Matching against catalog...")
pairwise_dir = temp_dir / "pairwise"
pairwise_dir.mkdir(exist_ok=True)
match_stage = run_matching_stage(
query_features=query_features,
catalog_path=config.catalog_root,
top_k=top_k,
extractor=extractor,
device=device,
query_image_path=str(cropped_path),
pairwise_output_dir=pairwise_dir,
filter_locations=filter_locations,
filter_body_parts=filter_body_parts,
)
matches = match_stage["data"]["matches"]
if not matches:
# Return 23 empty outputs (5 pipeline + 18 rank 1 details)
return (
"No matches found in catalog", # 1. result_text
None, # 2. seg_viz
cropped_image_pil, # 3. cropped_image
None, # 4. extracted_kpts_viz
[], # 5. dataset_samples
gr.update(value=None), # 6. matched_kpts_viz
gr.update(value=None), # 7. clean_comparison_viz
gr.update(value=""), # 8. header
gr.update(value=""), # 9. head indicator
gr.update(value=""), # 10. left_flank indicator
gr.update(value=""), # 11. right_flank indicator
gr.update(value=""), # 12. tail indicator
gr.update(value=""), # 13. misc indicator
gr.update(visible=False), # 14. head empty message
gr.update(visible=False), # 15. left_flank empty message
gr.update(visible=False), # 16. right_flank empty message
gr.update(visible=False), # 17. tail empty message
gr.update(visible=False), # 18. misc empty message
gr.update(value=[]), # 19. head gallery
gr.update(value=[]), # 20. left_flank gallery
gr.update(value=[]), # 21. right_flank gallery
gr.update(value=[]), # 22. tail gallery
gr.update(value=[]), # 23. misc gallery
)
# Top match
top_match = matches[0]
top_leopard_name = top_match["leopard_name"]
top_wasserstein = top_match["wasserstein"]
# Determine confidence level (higher Wasserstein = better match)
if top_wasserstein >= 0.12:
confidence_indicator = "🔵" # Excellent
elif top_wasserstein >= 0.07:
confidence_indicator = "🟢" # Good
elif top_wasserstein >= 0.04:
confidence_indicator = "🟡" # Fair
else:
confidence_indicator = "🔴" # Uncertain
result_text = f"## {confidence_indicator} {top_leopard_name.title()}"
# Create segmentation visualization
seg_viz = create_segmentation_viz(
image_path=temp_image_path, mask=selected_pred["mask"]
)
# Generate extracted keypoints visualization
extracted_kpts_viz = None
try:
# Extract keypoints from query features for visualization
query_kpts = query_features["keypoints"].cpu().numpy()
extracted_kpts_viz = draw_keypoints_overlay(
image_path=cropped_path,
keypoints=query_kpts,
max_keypoints=500,
color="blue",
ps=10,
)
except Exception as e:
logger.error(f"Error creating extracted keypoints visualization: {e}")
# Build dataset for top-K matches table
dataset_samples = []
match_visualizations = {}
clean_comparison_visualizations = {}
for match in matches:
rank = match["rank"]
leopard_name = match["leopard_name"]
wasserstein = match["wasserstein"]
catalog_img_path = Path(match["filepath"])
# Get location from catalog metadata
catalog_id = match["catalog_id"]
catalog_metadata = get_catalog_metadata_for_id(
config.catalog_root, catalog_id
)
location = "unknown"
if catalog_metadata:
# Extract location from path: database/{location}/{individual}/...
img_path_parts = Path(catalog_metadata["image_path"]).parts
if len(img_path_parts) >= 3:
# Find 'database' in path and get next part
try:
db_idx = img_path_parts.index("database")
if db_idx + 1 < len(img_path_parts):
location = img_path_parts[db_idx + 1]
except ValueError:
pass
# Confidence indicator (higher Wasserstein = better match)
if wasserstein >= 0.12:
indicator = "🔵" # Excellent
elif wasserstein >= 0.07:
indicator = "🟢" # Good
elif wasserstein >= 0.04:
indicator = "🟡" # Fair
else:
indicator = "🔴" # Uncertain
# Create visualizations for this match
npz_path = pairwise_dir / f"rank_{rank:02d}_{match['catalog_id']}.npz"
if npz_path.exists():
try:
pairwise_data = np.load(npz_path)
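# Arrays consumed from the NPZ: query_keypoints, catalog_keypoints,
# match_scores (shapes assumed to be (N, 2), (N, 2), (N,))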
# Create matched keypoints visualization
match_viz = draw_matched_keypoints(
query_image_path=cropped_path,
catalog_image_path=catalog_img_path,
query_keypoints=pairwise_data["query_keypoints"],
catalog_keypoints=pairwise_data["catalog_keypoints"],
match_scores=pairwise_data["match_scores"],
max_matches=100,
)
match_visualizations[rank] = match_viz
# Create clean comparison visualization
clean_viz = draw_side_by_side_comparison(
query_image_path=cropped_path,
catalog_image_path=catalog_img_path,
)
clean_comparison_visualizations[rank] = clean_viz
except Exception as e:
logger.error(f"Error creating visualizations for rank {rank}: {e}")
# Format for table (as list, not dict)
dataset_samples.append(
[
rank,
indicator,
leopard_name.title(),
location.replace("_", " ").title(),
f"{wasserstein:.4f}",
]
)
# Store match visualizations, enriched matches, filters, and temp_dir in global state
LOADED_MODELS["current_match_visualizations"] = match_visualizations
LOADED_MODELS["current_clean_comparison_visualizations"] = (
clean_comparison_visualizations
)
LOADED_MODELS["current_enriched_matches"] = matches
LOADED_MODELS["current_filter_body_parts"] = filter_body_parts
LOADED_MODELS["current_temp_dir"] = temp_dir
# Automatically load rank 1 details (visualizations + galleries)
rank1_details = load_match_details_for_rank(rank=1)
# Return 23 outputs total:
# - 5 pipeline outputs (result_text, seg_viz, cropped_image, extracted_kpts_viz, dataset_samples)
# - 18 rank 1 details (from load_match_details_for_rank)
return (
result_text, # 1. Top match result text
seg_viz, # 2. Segmentation overlay
cropped_image_pil, # 3. Cropped leopard
extracted_kpts_viz, # 4. Extracted keypoints
dataset_samples, # 5. Matches table data
# Unpack all 18 rank 1 details:
*rank1_details, # 6-23. visualizations, header, indicators, galleries
)
except Exception as e:
logger.error(f"Error processing image: {e}", exc_info=True)
# Return 23 empty outputs (5 pipeline + 18 rank 1 details)
return (
f"Error processing image: {str(e)}", # 1. result_text
None, # 2. seg_viz
None, # 3. cropped_image
None, # 4. extracted_kpts_viz
[], # 5. dataset_samples
gr.update(value=None), # 6. matched_kpts_viz
gr.update(value=None), # 7. clean_comparison_viz
gr.update(value=""), # 8. header
gr.update(value=""), # 9. head indicator
gr.update(value=""), # 10. left_flank indicator
gr.update(value=""), # 11. right_flank indicator
gr.update(value=""), # 12. tail indicator
gr.update(value=""), # 13. misc indicator
gr.update(visible=False), # 14. head empty message
gr.update(visible=False), # 15. left_flank empty message
gr.update(visible=False), # 16. right_flank empty message
gr.update(visible=False), # 17. tail empty message
gr.update(visible=False), # 18. misc empty message
gr.update(value=[]), # 19. head gallery
gr.update(value=[]), # 20. left_flank gallery
gr.update(value=[]), # 21. right_flank gallery
gr.update(value=[]), # 22. tail gallery
gr.update(value=[]), # 23. misc gallery
)
def create_segmentation_viz(image_path, mask):
"""Create visualization of segmentation mask overlaid on image."""
# Load original image (cv2.imread returns None on failure, so guard it)
img = cv2.imread(str(image_path))
if img is None:
raise ValueError(f"Could not read image for segmentation overlay: {image_path}")
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Resize mask to match image dimensions if needed
if mask.shape[:2] != img_rgb.shape[:2]:
mask_resized = cv2.resize(
mask.astype(np.uint8),
(img_rgb.shape[1], img_rgb.shape[0]),
interpolation=cv2.INTER_NEAREST,
)
else:
mask_resized = mask
# Create colored overlay
overlay = img_rgb.copy()
overlay[mask_resized > 0] = [255, 0, 0] # Red for masked region
# Blend
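# Weighted blend: blended = (1 - alpha) * img_rgb + alpha * overlay,
# i.e. a 40% red tint over the masked region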
alpha = 0.4
blended = cv2.addWeighted(img_rgb, 1 - alpha, overlay, alpha, 0)
return Image.fromarray(blended)
def load_match_details_for_rank(rank: int) -> tuple:
"""Load all match details (visualizations + galleries) for a specific rank.
This is a reusable helper function that encapsulates the logic for loading
match visualizations, galleries, and metadata for a given rank. Used by both
the automatic rank 1 display after pipeline completion and the interactive
row selection handler.
Args:
rank: The rank to load (1-indexed)
Returns:
Tuple of 18 Gradio component updates:
(matched_kpts_viz, clean_comparison_viz, header,
head_indicator, left_flank_indicator, right_flank_indicator, tail_indicator, misc_indicator,
head_empty_message, left_flank_empty_message, right_flank_empty_message,
tail_empty_message, misc_empty_message,
gallery_head, gallery_left_flank, gallery_right_flank, gallery_tail, gallery_misc)
"""
# Get stored data from global state
match_visualizations = LOADED_MODELS.get("current_match_visualizations", {})
clean_comparison_visualizations = LOADED_MODELS.get(
"current_clean_comparison_visualizations", {}
)
enriched_matches = LOADED_MODELS.get("current_enriched_matches", [])
filter_body_parts = LOADED_MODELS.get("current_filter_body_parts")
catalog_root = LOADED_MODELS.get("catalog_root")
# Find the match for the requested rank
selected_match = None
for match in enriched_matches:
if match["rank"] == rank:
selected_match = match
break
if not selected_match or rank not in match_visualizations:
# Return empty updates for all 18 outputs
return (
gr.update(value=None), # 1. matched_kpts_viz
gr.update(value=None), # 2. clean_comparison_viz
gr.update(value=""), # 3. header
gr.update(value=""), # 4. head indicator
gr.update(value=""), # 5. left_flank indicator
gr.update(value=""), # 6. right_flank indicator
gr.update(value=""), # 7. tail indicator
gr.update(value=""), # 8. misc indicator
gr.update(visible=False), # 9. head empty message
gr.update(visible=False), # 10. left_flank empty message
gr.update(visible=False), # 11. right_flank empty message
gr.update(visible=False), # 12. tail empty message
gr.update(visible=False), # 13. misc empty message
gr.update(value=[]), # 14. head gallery
gr.update(value=[]), # 15. left_flank gallery
gr.update(value=[]), # 16. right_flank gallery
gr.update(value=[]), # 17. tail gallery
gr.update(value=[]), # 18. misc gallery
)
# Get both visualizations
match_viz = match_visualizations[rank]
clean_viz = clean_comparison_visualizations.get(rank)
# Create dynamic header with leopard name
leopard_name = selected_match["leopard_name"]
header_text = f"## Reference Images for {leopard_name.title()}"
# Load galleries organized by body part
galleries = {}
if catalog_root:
try:
# Extract location from match filepath
location = None
filepath = Path(selected_match["filepath"])
parts = filepath.parts
if "database" in parts:
db_idx = parts.index("database")
if db_idx + 1 < len(parts):
location = parts[db_idx + 1]
galleries = load_matched_individual_gallery_by_body_part(
catalog_root=catalog_root,
leopard_name=leopard_name,
location=location,
)
except Exception as e:
logger.error(f"Error loading gallery for {leopard_name}: {e}")
# Initialize empty galleries on error
galleries = {
"head": [],
"left_flank": [],
"right_flank": [],
"tail": [],
"misc": [],
}
# Mark body parts that were included in the active filter
def get_indicator(body_part: str) -> str:
"""Return a "* (filtered)" marker if the body part was filtered on, else ""."""
if filter_body_parts and body_part in filter_body_parts:
return "* (filtered)"
return ""
# Helper to determine if empty message should be visible
def is_empty(body_part: str) -> bool:
"""Return True if no images for this body part."""
return len(galleries.get(body_part, [])) == 0
return (
gr.update(value=match_viz), # 1. matched_kpts_viz
gr.update(value=clean_viz), # 2. clean_comparison_viz
gr.update(value=header_text), # 3. header
gr.update(value=get_indicator("head")), # 4. head indicator
gr.update(value=get_indicator("left_flank")), # 5. left_flank indicator
gr.update(value=get_indicator("right_flank")), # 6. right_flank indicator
gr.update(value=get_indicator("tail")), # 7. tail indicator
gr.update(value=get_indicator("misc")), # 8. misc indicator
gr.update(visible=is_empty("head")), # 9. head empty message
gr.update(visible=is_empty("left_flank")), # 10. left_flank empty message
gr.update(visible=is_empty("right_flank")), # 11. right_flank empty message
gr.update(visible=is_empty("tail")), # 12. tail empty message
gr.update(visible=is_empty("misc")), # 13. misc empty message
gr.update(
value=galleries.get("head", []), visible=not is_empty("head")
), # 14. head gallery
gr.update(
value=galleries.get("left_flank", []), visible=not is_empty("left_flank")
), # 15. left_flank gallery
gr.update(
value=galleries.get("right_flank", []), visible=not is_empty("right_flank")
), # 16. right_flank gallery
gr.update(
value=galleries.get("tail", []), visible=not is_empty("tail")
), # 17. tail gallery
gr.update(
value=galleries.get("misc", []), visible=not is_empty("misc")
), # 18. misc gallery
)
def on_match_selected(evt: gr.SelectData):
"""Handle selection of a match from the dataset table.
Returns both visualizations, header, indicators, empty messages,
and galleries organized by body part.
"""
# evt.index is [row, col] for Dataframe, we want row
if isinstance(evt.index, (list, tuple)):
selected_row = evt.index[0]
else:
selected_row = evt.index
selected_rank = selected_row + 1 # Ranks are 1-indexed
# Delegate to the reusable helper function
return load_match_details_for_rank(selected_rank)
def load_matched_individual_gallery_by_body_part(
catalog_root: Path,
leopard_name: str,
location: str | None = None,
) -> dict[str, list[tuple]]:
"""Load all images for a matched individual organized by body part.
Args:
catalog_root: Path to catalog root directory
leopard_name: Name of the matched individual (e.g., "karindas")
location: Geographic location (e.g., "skycrest_valley")
Returns:
Dict mapping body part to list of (PIL.Image, caption) tuples:
{
"head": [(img1, caption1), (img2, caption2), ...],
"left_flank": [...],
"right_flank": [...],
"tail": [...],
"misc": [...]
}
"""
# Initialize dict with all body parts
galleries = {
"head": [],
"left_flank": [],
"right_flank": [],
"tail": [],
"misc": [],
}
# Find metadata path: database/{location}/{individual}/metadata.yaml
if location:
metadata_path = (
catalog_root / "database" / location / leopard_name / "metadata.yaml"
)
else:
# Try to find the individual in any location
metadata_path = None
database_dir = catalog_root / "database"
if database_dir.exists():
for loc_dir in database_dir.iterdir():
if loc_dir.is_dir():
potential_path = loc_dir / leopard_name / "metadata.yaml"
if potential_path.exists():
metadata_path = potential_path
break
if not metadata_path or not metadata_path.exists():
logger.warning(f"Metadata not found for {leopard_name}")
return galleries
try:
metadata = load_leopard_metadata(metadata_path)
# Load all images organized by body part
for img_entry in metadata["reference_images"]:
body_part = img_entry.get("body_part", "misc")
# Normalize body_part to match our keys
if body_part not in galleries:
body_part = "misc" # Default to misc if unknown
# Load image
img_path = catalog_root / "database" / img_entry["path"]
try:
img = Image.open(img_path)
# Simple caption: just body part name
caption = body_part
galleries[body_part].append((img, caption))
except Exception as e:
logger.error(f"Error loading image {img_path}: {e}")
except Exception as e:
logger.error(f"Error loading metadata for {leopard_name}: {e}")
return galleries
def cleanup_temp_files():
"""Clean up temporary files from previous run."""
temp_dir = LOADED_MODELS.get("current_temp_dir")
if temp_dir and temp_dir.exists():
try:
shutil.rmtree(temp_dir)
logger.info(f"Cleaned up temporary directory: {temp_dir}")
except Exception as e:
logger.warning(f"Error cleaning up temp directory: {e}")
def create_leopard_tab(leopard_metadata, config: AppConfig):
"""Create a tab for displaying a single leopard's images.
Args:
leopard_metadata: Metadata dictionary for the leopard individual
config: Application configuration
"""
# Support both 'leopard_name' and 'individual_name' keys
leopard_name = leopard_metadata.get("leopard_name") or leopard_metadata.get(
"individual_name"
)
location = leopard_metadata.get("location", "unknown")
total_images = leopard_metadata["statistics"]["total_reference_images"]
# Get body parts from statistics
body_parts = leopard_metadata["statistics"].get(
"body_parts_represented", leopard_metadata["statistics"].get("body_parts", [])
)
body_parts_str = ", ".join(body_parts) if body_parts else "N/A"
with gr.Tab(leopard_name):
# Header with statistics
gr.Markdown(
f"### {leopard_name.title()}\n"
f"**Location:** {location.replace('_', ' ').title()} | "
f"**{total_images} images** | "
f"**Body parts:** {body_parts_str}"
)
# Load all images with body_part captions
gallery_data = []
for img_entry in leopard_metadata["reference_images"]:
img_path = config.catalog_root / "database" / img_entry["path"]
body_part = img_entry.get("body_part", "unknown")
try:
img = Image.open(img_path)
# Caption format: just body_part (location is already in tab)
caption = body_part
gallery_data.append((img, caption))
except Exception as e:
logger.error(f"Error loading image {img_path}: {e}")
# Display gallery
gr.Gallery(
value=gallery_data,
label=f"Reference Images for {leopard_name.title()}",
columns=6,
height=700,
object_fit="scale-down",
allow_preview=True,
)
def create_app(config: AppConfig):
"""Create and configure the Gradio application.
Args:
config: Application configuration
"""
# Initialize models at startup
initialize_models(config)
# Load catalog data
catalog_index, individuals_data = load_catalog_data(config)
# Build example images list from examples directory
example_images = (
list(config.examples_dir.glob("*.jpg"))
+ list(config.examples_dir.glob("*.JPG"))
+ list(config.examples_dir.glob("*.png"))
)
# Sort with Ayima images last
example_images.sort(key=lambda x: (1 if "Ayima" in x.name else 0, x.name))
# Create interface
with gr.Blocks(title="Snow Leopard Identification") as app:
# Hidden state to track which example image was selected (for cache lookup)
selected_example_state = gr.State(value=None)
# Main tabs
with gr.Tabs():
# Tab 1: Identify Snow Leopard
with gr.Tab("Identify Snow Leopard"):
gr.Markdown("""
Upload a snow leopard image or select an example to identify which individual it is.
The system will detect the leopard, extract distinctive features, and match against the catalog.
""")
with gr.Row():
# Left column: Input
with gr.Column(scale=1):
image_input = gr.Image(
type="pil",
label="Upload Snow Leopard Image",
sources=["upload", "clipboard"],
)
examples_component = gr.Examples(
examples=[[str(img)] for img in example_images],
inputs=image_input,
label="Example Images",
)
# Track example selection for cache lookup
def on_example_select(evt: gr.SelectData):
"""Update state when an example is selected."""
if evt.index is not None:
return str(example_images[evt.index])
return None
# Example selection sets the state via examples_component.dataset.select
# below; a fresh upload clears it via image_input.upload.
examples_component.dataset.select(
fn=on_example_select,
outputs=[selected_example_state],
)
# Clear example state when user uploads a new image
image_input.upload(
fn=lambda: None,
outputs=[selected_example_state],
)
# Location filter dropdown
available_locations = get_available_locations(
config.catalog_root
)
location_filter = gr.Dropdown(
choices=available_locations,
value=["all"],
multiselect=True,
label="Filter by Location",
info="Select locations to search (default: all locations)",
)
# Body part filter dropdown
available_body_parts = get_available_body_parts(
config.catalog_root
)
body_part_filter = gr.Dropdown(
choices=available_body_parts,
value=["all"],
multiselect=True,
label="Filter by Body Part",
info="Select body parts to match (default: all body parts)",
)
# Advanced Configuration Accordion
with gr.Accordion("Advanced Configuration", open=False):
# Feature extractor dropdown
available_extractors = get_available_extractors(
config.catalog_root
)
extractor_dropdown = gr.Dropdown(
choices=available_extractors,
value="sift"
if "sift" in available_extractors
else (
available_extractors[0]
if available_extractors
else "sift"
),
label="Feature Extractor",
info=f"Available: {', '.join(available_extractors)}",
scale=1,
)
# Top-K parameter
top_k_input = gr.Number(
value=config.top_k,
label="Top-K Matches",
info="Number of top matches to return",
minimum=1,
maximum=20,
step=1,
precision=0,
scale=1,
)
submit_btn = gr.Button(
value="Identify Snow Leopard",
variant="primary",
size="lg",
)
# Right column: Results
with gr.Column(scale=4):
# Top-1 prediction
result_text = gr.Markdown("")
# Tabs for different result views
with gr.Tabs():
with gr.Tab("Model Internals"):
gr.Markdown("""
View the internal processing steps: segmentation mask, cropped leopard, and extracted keypoints.
""")
with gr.Row():
seg_viz = gr.Image(
label="Segmentation Overlay",
type="pil",
)
cropped_image = gr.Image(
label="Extracted Snow Leopard",
type="pil",
)
extracted_kpts_viz = gr.Image(
label="Extracted Keypoints",
type="pil",
)
with gr.Tab("Top Matches"):
gr.Markdown("""
Click a row to view detailed feature matching visualization and all reference images for that leopard.
**Higher Wasserstein distance = better match** (typical range: 0.04-0.27)
**Confidence Levels:** 🔵 Excellent (>=0.12) | 🟢 Good (>=0.07) | 🟡 Fair (>=0.04) | 🔴 Uncertain (<0.04)
""")
matches_dataset = gr.Dataframe(
headers=[
"Rank",
"Confidence",
"Leopard Name",
"Location",
"Wasserstein",
],
label="Top Matches",
wrap=True,
col_count=(5, "fixed"),
)
# Visualization container (always visible, images populated on pipeline completion)
with gr.Column() as viz_tabs:
# Tabbed visualization views
with gr.Tabs():
with gr.Tab("Matched Keypoints"):
gr.Markdown(
"Feature matching with keypoints and confidence-coded connecting lines. "
"**Green** = high confidence, **Yellow** = medium, **Red** = low."
)
matched_kpts_viz = gr.Image(
type="pil",
show_label=False,
)
with gr.Tab("Clean Comparison"):
gr.Markdown(
"Side-by-side comparison without feature annotations. "
"Useful for assessing overall visual similarity and spotting patterns."
)
clean_comparison_viz = gr.Image(
type="pil",
show_label=False,
)
# Dynamic header showing matched leopard name
selected_match_header = gr.Markdown(
"", visible=True
)
# Create tabs for each body part
with gr.Tabs():
with gr.Tab("Head"):
head_indicator = gr.Markdown("")
head_empty_message = gr.Markdown(
value='<div style="text-align: center; padding: 60px 20px; color: #888;">'
'<p style="font-size: 16px;">No reference images available for this body part</p>'
"</div>",
visible=False,
)
gallery_head = gr.Gallery(
columns=6,
height=400,
object_fit="scale-down",
allow_preview=True,
)
with gr.Tab("Left Flank"):
left_flank_indicator = gr.Markdown("")
left_flank_empty_message = gr.Markdown(
value='<div style="text-align: center; padding: 60px 20px; color: #888;">'
'<p style="font-size: 16px;">No reference images available for this body part</p>'
"</div>",
visible=False,
)
gallery_left_flank = gr.Gallery(
columns=6,
height=400,
object_fit="scale-down",
allow_preview=True,
)
with gr.Tab("Right Flank"):
right_flank_indicator = gr.Markdown("")
right_flank_empty_message = gr.Markdown(
value='<div style="text-align: center; padding: 60px 20px; color: #888;">'
'<p style="font-size: 16px;">No reference images available for this body part</p>'
"</div>",
visible=False,
)
gallery_right_flank = gr.Gallery(
columns=6,
height=400,
object_fit="scale-down",
allow_preview=True,
)
with gr.Tab("Tail"):
tail_indicator = gr.Markdown("")
tail_empty_message = gr.Markdown(
value='<div style="text-align: center; padding: 60px 20px; color: #888;">'
'<p style="font-size: 16px;">No reference images available for this body part</p>'
"</div>",
visible=False,
)
gallery_tail = gr.Gallery(
columns=6,
height=400,
object_fit="scale-down",
allow_preview=True,
)
with gr.Tab("Other"):
misc_indicator = gr.Markdown("")
misc_empty_message = gr.Markdown(
value='<div style="text-align: center; padding: 60px 20px; color: #888;">'
'<p style="font-size: 16px;">No reference images available for this body part</p>'
"</div>",
visible=False,
)
gallery_misc = gr.Gallery(
columns=6,
height=400,
object_fit="scale-down",
allow_preview=True,
)
# Connect submit button
submit_btn.click(
fn=lambda img, ext, top_k, locs, parts, ex_path: run_identification(
image=img,
extractor=ext,
top_k=int(top_k),
selected_locations=locs,
selected_body_parts=parts,
example_path=ex_path,
config=config,
),
inputs=[
image_input,
extractor_dropdown,
top_k_input,
location_filter,
body_part_filter,
selected_example_state,
],
outputs=[
# Pipeline outputs (5 total)
result_text,
seg_viz,
cropped_image,
extracted_kpts_viz,
matches_dataset,
# Rank 1 auto-display outputs (18 total)
matched_kpts_viz,
clean_comparison_viz,
selected_match_header,
head_indicator,
left_flank_indicator,
right_flank_indicator,
tail_indicator,
misc_indicator,
head_empty_message,
left_flank_empty_message,
right_flank_empty_message,
tail_empty_message,
misc_empty_message,
gallery_head,
gallery_left_flank,
gallery_right_flank,
gallery_tail,
gallery_misc,
],
)
# Connect dataset selection
matches_dataset.select(
fn=on_match_selected,
outputs=[
matched_kpts_viz,
clean_comparison_viz,
selected_match_header,
head_indicator,
left_flank_indicator,
right_flank_indicator,
tail_indicator,
misc_indicator,
head_empty_message,
left_flank_empty_message,
right_flank_empty_message,
tail_empty_message,
misc_empty_message,
gallery_head,
gallery_left_flank,
gallery_right_flank,
gallery_tail,
gallery_misc,
],
)
# Tab 2: Explore Catalog
with gr.Tab("Explore Catalog"):
gr.Markdown(
"""
## Snow Leopard Catalog Browser
Browse the reference catalog of known snow leopard individuals.
Each individual has multiple reference images covering different body parts and capture locations.
"""
)
# Display catalog statistics
stats = catalog_index.get("statistics", {})
formatted_locations = [
loc.replace("_", " ").title() for loc in stats.get("locations", [])
]
gr.Markdown(
f"""
### Catalog Statistics
- **Total Individuals:** {stats.get("total_individuals", "N/A")}
- **Total Images:** {stats.get("total_reference_images", "N/A")}
- **Locations:** {", ".join(formatted_locations)}
- **Body Parts:** {", ".join(stats.get("body_parts", []))}
"""
)
gr.Markdown("---")
gr.Markdown("### Individual Leopards by Location")
# Group individuals by location
individuals_by_location = {}
for individual_data in individuals_data:
location = individual_data.get("location", "unknown")
if location not in individuals_by_location:
individuals_by_location[location] = []
individuals_by_location[location].append(individual_data)
# Create tabs for each location
with gr.Tabs():
for location in sorted(individuals_by_location.keys()):
with gr.Tab(f"{location.replace('_', ' ').title()}"):
# Create subtabs for each individual in this location
with gr.Tabs():
for leopard_data in individuals_by_location[location]:
create_leopard_tab(
leopard_metadata=leopard_data, config=config
)
# Cleanup on app close
app.unload(cleanup_temp_files)
# Load first example image on startup
def load_first_example():
"""Load the first example image when the app starts.
Returns both the image AND the path so the cache can be used
when the user clicks Identify without selecting a new example.
"""
if example_images:
try:
first_image = Image.open(example_images[0])
first_path = str(example_images[0])
return first_image, first_path
except Exception as e:
logger.error(f"Error loading first example image: {e}")
return None, None
return None, None
app.load(fn=load_first_example, outputs=[image_input, selected_example_state])
return app
if __name__ == "__main__":
# Extract data archives on first run (for HF Spaces deployment)
ensure_data_extracted()
# Ensure SAM model is downloaded
logger.info("Checking for SAM HQ model...")
sam_path = ensure_sam_model()
# Validate required directories exist
if not CATALOG_ROOT.exists():
logger.error(f"Catalog not found: {CATALOG_ROOT}")
logger.error("Please ensure catalog data is present in data/catalog/")
sys.exit(1)
if not EXAMPLES_DIR.exists():
logger.warning(f"Examples directory not found: {EXAMPLES_DIR}")
EXAMPLES_DIR.mkdir(parents=True, exist_ok=True)
# Create config
config = AppConfig(
model_path=None, # Not using YOLO
catalog_root=CATALOG_ROOT,
examples_dir=EXAMPLES_DIR,
top_k=TOP_K_DEFAULT,
port=7860,
share=False,
sam_checkpoint_path=sam_path,
sam_model_type=SAM_MODEL_TYPE,
gdino_model_id=GDINO_MODEL_ID,
text_prompt=TEXT_PROMPT,
)
# Build and launch app
logger.info("Building Gradio interface...")
app = create_app(config)
logger.info("Launching app...")
app.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
)