achouffe committed
Commit 7870cc2 · verified · 1 Parent(s): 1cab1bb

feat: initial gradio app

.gitattributes CHANGED
@@ -1,3 +1,12 @@
+ # Images (catalog and examples)
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
+ *.JPG filter=lfs diff=lfs merge=lfs -text
+ *.JPEG filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.PNG filter=lfs diff=lfs merge=lfs -text
+
+ # Archives and models
  *.7z filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
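The new rules list each image extension in both lower and upper case. A minimal sketch of why that matters, assuming Git matches attribute patterns case-sensitively (this can differ when `core.ignorecase` is set): the helper `is_lfs_tracked` is hypothetical, its pattern list covers only the rules visible in this hunk, and `fnmatch.fnmatchcase` is used as an approximation of Git's basename matching for patterns without a slash.

```python
from fnmatch import fnmatchcase
from pathlib import PurePosixPath

# Only the patterns visible in this hunk; the real file may contain more.
LFS_PATTERNS = [
    "*.jpg", "*.jpeg", "*.JPG", "*.JPEG", "*.png", "*.PNG",
    "*.7z", "*.arrow", "*.bin",
]


def is_lfs_tracked(path: str, patterns: list[str] = LFS_PATTERNS) -> bool:
    """Approximate check: does the file name match any LFS pattern?

    Patterns without a slash are compared against the basename only,
    and the comparison is case-sensitive.
    """
    name = PurePosixPath(path).name
    return any(fnmatchcase(name, pattern) for pattern in patterns)
```

Under these assumptions a file ending in `.JPG` is covered, while a mixed-case suffix such as `.Jpg` is not and would be committed as a regular Git blob.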
.gitignore ADDED
@@ -0,0 +1,60 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Virtual environments
+ .venv/
+ venv/
+ ENV/
+ env/
+
+ # uv
+ uv.lock
+
+ # IDE
+ .idea/
+ .vscode/
+ *.swp
+ *.swo
+ *~
+
+ # Models (downloaded at runtime)
+ data/models/
+
+ # Extracted data (regenerated from archives on first run)
+ data/catalog/
+ cached_results/
+
+ # Keep archives tracked (these are the source of truth)
+ !data/catalog.tar.gz
+ !data/cache.tar.gz
+
+ # Temp files
+ *.tmp
+ *.temp
+ .DS_Store
+ Thumbs.db
+
+ # Jupyter
+ .ipynb_checkpoints/
+
+ # Logs
+ *.log
Makefile ADDED
@@ -0,0 +1,31 @@
+ .PHONY: help install run clean create-archives extract-data precompute-cache archive-info
+
+ help: ## Show this help message
+ 	@echo "Usage: make [target]"
+ 	@echo ""
+ 	@echo "Available targets:"
+ 	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-15s\033[0m %s\n", $$1, $$2}'
+
+ install: ## Install dependencies with uv
+ 	uv sync
+
+ run: ## Run the Gradio app locally
+ 	uv run python app.py
+
+ clean: ## Clean up temporary files and caches
+ 	rm -rf .venv uv.lock
+ 	find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
+ 	find . -type f -name "*.pyc" -delete
+ 	find . -type f -name "*.pyo" -delete
+
+ create-archives: ## Create compressed archives from catalog and cache
+ 	uv run python scripts/create_archives.py
+
+ extract-data: ## Extract archives (done automatically on first run)
+ 	uv run python -c "from snowleopard_reid.data_setup import ensure_data_extracted; ensure_data_extracted()"
+
+ precompute-cache: ## Run pipeline on all examples to generate cache
+ 	uv run python scripts/precompute_cache.py
+
+ archive-info: ## Show info about archives and directories
+ 	uv run python scripts/create_archives.py --info
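The `help` target builds its listing from the `## ` comment that follows each target name. A rough Python sketch of the same idea, for readers less familiar with the `grep | sort | awk` pipeline: the function name `parse_help` is hypothetical, and the regular expression is adapted from the `grep -E` pattern in the Makefile rather than taken from any project code.

```python
import re

# Same shape as the Makefile's grep pattern: "target: ... ## description".
HELP_LINE = re.compile(r"^([a-zA-Z_-]+):.*?## (.*)$")


def parse_help(makefile_text: str) -> list[tuple[str, str]]:
    """Return (target, description) pairs for self-documented targets, sorted by target."""
    entries = []
    for line in makefile_text.splitlines():
        match = HELP_LINE.match(line)
        if match:
            entries.append((match.group(1), match.group(2)))
    return sorted(entries)
```

Lines such as `.PHONY: ...` and indented recipe lines are skipped because they do not start with a bare target name, which is why only documented targets appear in the help output.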
README.md CHANGED
@@ -1,13 +1,56 @@
  ---
- title: Snowleopard Reid
- emoji: 🐠
+ title: Snowleopard reID
+ emoji: 🐆
  colorFrom: blue
  colorTo: indigo
  sdk: gradio
- sdk_version: 6.0.1
+ sdk_version: 5.49.1
+ python_version: 3.11
  app_file: app.py
  pinned: false
- short_description: snow leopard reID using AI
+ short_description: Snow Leopard reID using AI
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Snow Leopard Re-Identification
+
+ AI-powered snow leopard re-identification system using computer vision and deep learning.
+
+ ## Features
+
+ - **Single Image Matching**: Upload an image to identify individual snow leopards against a catalog
+ - **Batch Processing**: Process multiple images at once with filtering options
+ - **Multiple Detection Methods**: YOLO or Grounding DINO + SAM HQ for segmentation
+ - **Multiple Matching Algorithms**: SIFT, SuperPoint, DISK, or ALIKED feature extractors
+
+ ## Local Development
+
+ ### Prerequisites
+
+ - Python 3.11+
+ - [uv](https://github.com/astral-sh/uv) package manager
+
+ ### Setup
+
+ ```bash
+ # Install dependencies
+ make install
+
+ # Run the app locally
+ make run
+ ```
+
+ The app will be available at `http://localhost:7860`.
+
+ ## Data
+
+ - **Catalog**: Pre-computed features for known snow leopards (stored with Git LFS)
+ - **SAM HQ Model**: Downloaded automatically at runtime from HuggingFace Hub
+
+ ## Tech Stack
+
+ - Gradio for the web interface
+ - PyTorch for deep learning
+ - Grounding DINO for zero-shot object detection
+ - SAM HQ (Segment Anything Model High Quality) for segmentation
+ - LightGlue for feature matching
+ - Wasserstein distance for match scoring
@@ -0,0 +1,1608 @@
+ """Gradio web application for snow leopard identification and catalog exploration.
+
+ This interactive web interface provides an easy-to-use frontend for the snow
+ leopard identification system. Users can upload images, view matches against the catalog,
+ and explore reference leopards through a browser-based UI powered by Gradio.
+
+ Features:
+     - Upload snow leopard images or select from examples
+     - Run full identification pipeline with GDINO+SAM segmentation
+     - View top-K matches with Wasserstein distance scores
+     - Explore complete leopard catalog with thumbnails
+     - Visualize matched keypoints between query and catalog images
+
+ Usage:
+     # Local testing with uv:
+     uv sync
+     uv run python app.py
+
+     # Deployed on Hugging Face Spaces
+ """
+
+ import sys
+ from pathlib import Path
+
+ # Add src to path for imports BEFORE importing snowleopard_reid
+ SPACE_ROOT = Path(__file__).parent
+ sys.path.insert(0, str(SPACE_ROOT / "src"))
+
+ import logging
+ import shutil
+ import tempfile
+ from dataclasses import dataclass
+
+ import cv2
+ import gradio as gr
+ import numpy as np
+ import torch
+ import yaml
+ from huggingface_hub import hf_hub_download
+ from PIL import Image
+
+ from snowleopard_reid.catalog import (
+     get_available_body_parts,
+     get_available_locations,
+     get_catalog_metadata_for_id,
+     load_catalog_index,
+     load_leopard_metadata,
+ )
+ from snowleopard_reid.pipeline.stages import (
+     run_feature_extraction_stage,
+     run_matching_stage,
+     run_preprocess_stage,
+     run_segmentation_stage,
+     select_best_mask,
+ )
+ from snowleopard_reid.pipeline.stages.segmentation import (
+     load_gdino_model,
+     load_sam_predictor,
+ )
+ from snowleopard_reid.visualization import (
+     draw_keypoints_overlay,
+     draw_matched_keypoints,
+     draw_side_by_side_comparison,
+ )
+ from snowleopard_reid.cache import (
+     filter_cached_matches,
+     generate_visualizations_from_npz,
+     is_cached,
+     load_cached_results,
+ )
+ from snowleopard_reid.data_setup import ensure_data_extracted
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ )
+ logger = logging.getLogger(__name__)
+
+ # Configuration (hardcoded for HF Spaces / local dev)
+ CATALOG_ROOT = SPACE_ROOT / "data" / "catalog"
+ SAM_CHECKPOINT_DIR = SPACE_ROOT / "data" / "models"
+ SAM_CHECKPOINT_NAME = "sam_hq_vit_l.pth"
+ EXAMPLES_DIR = SPACE_ROOT / "data" / "examples"
+ GDINO_MODEL_ID = "IDEA-Research/grounding-dino-base"
+ TEXT_PROMPT = "a snow leopard."
+ TOP_K_DEFAULT = 5
+ SAM_MODEL_TYPE = "vit_l"
+
+
+ @dataclass
+ class AppConfig:
+     """Configuration for the Snow Leopard ID UI application."""
+
+     model_path: Path | None
+     catalog_root: Path
+     examples_dir: Path
+     top_k: int
+     port: int
+     share: bool
+     # GDINO+SAM parameters
+     sam_checkpoint_path: Path
+     sam_model_type: str
+     gdino_model_id: str
+     text_prompt: str
+
+
+ def ensure_sam_model() -> Path:
+     """Download SAM HQ model if not present.
+
+     Returns:
+         Path to the SAM HQ checkpoint file
+     """
+     sam_path = SAM_CHECKPOINT_DIR / SAM_CHECKPOINT_NAME
+     if not sam_path.exists():
+         logger.info("Downloading SAM HQ model (1.6GB)...")
+         SAM_CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
+         hf_hub_download(
+             repo_id="lkeab/hq-sam",
+             filename=SAM_CHECKPOINT_NAME,
+             local_dir=SAM_CHECKPOINT_DIR,
+         )
+         logger.info("SAM HQ model downloaded successfully")
+     return sam_path
+
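`ensure_sam_model` follows a download-if-missing pattern: check for the file, create the directory, fetch, return the path. A minimal, network-free sketch of that pattern is shown here. The names `ensure_file` and `fetch` are hypothetical and not part of the commit; in the app, `fetch` would wrap the `hf_hub_download` call. The check after fetching is an addition that the original function does not perform.

```python
from pathlib import Path
from typing import Callable


def ensure_file(path: Path, fetch: Callable[[Path], None]) -> Path:
    """Return `path`, calling `fetch(path)` first if the file is missing.

    `fetch` is any callable that writes the file at `path`; injecting it
    keeps the caching logic testable without network access.
    """
    if not path.exists():
        path.parent.mkdir(parents=True, exist_ok=True)
        fetch(path)
        # Extra safety check (not in the original): fail early if the
        # fetch step did not produce the expected file.
        if not path.exists():
            raise FileNotFoundError(f"fetch did not create {path}")
    return path
```

Because the existence check runs on every call, a second call with the same path returns immediately without fetching again, which matches the intent of downloading the checkpoint only on first run.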
+
+ def get_available_extractors(catalog_root: Path) -> list[str]:
+     """Get list of available feature extractors from catalog.
+
+     Args:
+         catalog_root: Root directory of the leopard catalog
+
+     Returns:
+         List of available extractor names (e.g., ['sift', 'superpoint'])
+     """
+     try:
+         catalog_index = load_catalog_index(catalog_root)
+         extractors = list(catalog_index.get("feature_extractors", {}).keys())
+         if not extractors:
+             logger.warning(f"No extractors found in catalog at {catalog_root}")
+             return ["sift"]  # Default fallback
+         return extractors
+     except Exception as e:
+         logger.error(f"Failed to load catalog index: {e}")
+         return ["sift"]  # Default fallback
+
+
+ # Global state for models and catalog (loaded at startup)
+ LOADED_MODELS = {}
+
+
+ def load_catalog_data(config: AppConfig):
+     """Load catalog index and individual leopard metadata.
+
+     Args:
+         config: Application configuration containing catalog_root
+
+     Returns:
+         Tuple of (catalog_index, individuals_data)
+     """
+     catalog_index_path = config.catalog_root / "catalog_index.yaml"
+
+     # Load catalog index
+     with open(catalog_index_path) as f:
+         catalog_index = yaml.safe_load(f)
+
+     # Load metadata for each individual
+     individuals_data = []
+     for individual in catalog_index["individuals"]:
+         metadata_path = config.catalog_root / individual["metadata_path"]
+         with open(metadata_path) as f:
+             leopard_metadata = yaml.safe_load(f)
+         individuals_data.append(leopard_metadata)
+
+     return catalog_index, individuals_data
+
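`load_catalog_data` and `get_available_extractors` together imply a shape for `catalog_index.yaml`: a top-level `individuals` list whose entries carry a `metadata_path`, plus a `feature_extractors` mapping. The sketch below checks an already-parsed index against only those keys. The function `check_catalog_index` is hypothetical, and the real catalog schema may contain more fields than this diff shows.

```python
def check_catalog_index(index: dict) -> list[str]:
    """Return a list of human-readable problems; an empty list means the
    keys read by the loaders in this diff are present."""
    problems = []

    individuals = index.get("individuals")
    if not isinstance(individuals, list):
        problems.append("'individuals' must be a list")
        individuals = []

    for i, entry in enumerate(individuals):
        # load_catalog_data indexes entry["metadata_path"] directly,
        # so a missing key would raise KeyError at startup.
        if not isinstance(entry, dict) or "metadata_path" not in entry:
            problems.append(f"individuals[{i}] is missing 'metadata_path'")

    # get_available_extractors falls back to ['sift'] when this is empty.
    if not index.get("feature_extractors"):
        problems.append("no 'feature_extractors' listed; the app falls back to 'sift'")

    return problems
```

Running such a check before the loops in `load_catalog_data` would turn a bare `KeyError` into a readable message, at the cost of one extra pass over the index.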
+
+ def initialize_models(config: AppConfig):
+     """Load models at startup for faster inference.
+
+     Args:
+         config: Application configuration containing model paths
+     """
+     logger.info("Initializing models...")
+
+     # Check for GPU
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     logger.info(f"Using device: {device}")
+
+     if device == "cuda":
+         gpu_name = torch.cuda.get_device_name(0)
+         gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
+         logger.info(f"GPU: {gpu_name} ({gpu_memory:.1f} GB)")
+
+     # Load Grounding DINO model
+     logger.info(f"Loading Grounding DINO model: {config.gdino_model_id}")
+     gdino_processor, gdino_model = load_gdino_model(
+         model_id=config.gdino_model_id,
+         device=device,
+     )
+     LOADED_MODELS["gdino_processor"] = gdino_processor
+     LOADED_MODELS["gdino_model"] = gdino_model
+     logger.info("Grounding DINO model loaded successfully")
+
+     # Load SAM HQ model
+     logger.info(
+         f"Loading SAM HQ model from {config.sam_checkpoint_path} (type: {config.sam_model_type})"
+     )
+     sam_predictor = load_sam_predictor(
+         checkpoint_path=config.sam_checkpoint_path,
+         model_type=config.sam_model_type,
+         device=device,
+     )
+     LOADED_MODELS["sam_predictor"] = sam_predictor
+     logger.info("SAM HQ model loaded successfully")
+
+     # Store device info and catalog root for callbacks
+     LOADED_MODELS["device"] = device
+     LOADED_MODELS["catalog_root"] = config.catalog_root
+     LOADED_MODELS["text_prompt"] = config.text_prompt
+
+     logger.info("Models initialized successfully")
+
+
+ def _load_from_cache(
+     example_path: str,
+     extractor: str,
+     config: "AppConfig",
+     filter_locations: list[str] | None = None,
+     filter_body_parts: list[str] | None = None,
+     top_k: int = 5,
+ ):
+     """Load cached pipeline results with optional filtering and return UI component updates.
+
+     Supports the v2.0 cache format which stores ALL matches with location/body_part
+     metadata, enabling client-side filtering without re-running the pipeline.
+
+     Args:
+         example_path: Path to the example image
+         extractor: Feature extractor name
+         config: Application configuration
+         filter_locations: Optional list of locations to filter by
+         filter_body_parts: Optional list of body parts to filter by
+         top_k: Number of top matches to return after filtering
+
+     Returns:
+         Tuple of 23 UI components matching run_identification output
+     """
+     # Load cached results
+     cached = load_cached_results(example_path, extractor)
+     predictions = cached["predictions"]
+
+     # Support both v1.0 ("matches") and v2.0 ("all_matches") cache formats
+     if "all_matches" in predictions:
+         all_matches = predictions["all_matches"]
+     else:
+         # Fallback for v1.0 cache format (no filtering support)
+         all_matches = predictions.get("matches", [])
+
+     # Filter and re-rank matches
+     matches = filter_cached_matches(
+         all_matches=all_matches,
+         filter_locations=filter_locations,
+         filter_body_parts=filter_body_parts,
+         top_k=top_k,
+     )
+
+     if not matches:
+         # No matches after filtering - return empty results
+         return (
+             "No matches found with the selected filters",
+             cached["segmentation_image"],
+             cached["cropped_image"],
+             cached["keypoints_image"],
+             [],
+             gr.update(value=None),
+             gr.update(value=None),
+             gr.update(value=""),
+             gr.update(value=""),
+             gr.update(value=""),
+             gr.update(value=""),
+             gr.update(value=""),
+             gr.update(value=""),
+             gr.update(visible=False),
+             gr.update(visible=False),
+             gr.update(visible=False),
+             gr.update(visible=False),
+             gr.update(visible=False),
+             gr.update(value=[]),
+             gr.update(value=[]),
+             gr.update(value=[]),
+             gr.update(value=[]),
+             gr.update(value=[]),
+         )
+
+     # Generate visualizations on-demand from NPZ data
+     logger.info(f"Generating visualizations for {len(matches)} filtered matches...")
+     match_visualizations, clean_comparison_visualizations = (
+         generate_visualizations_from_npz(
+             pairwise_dir=cached["pairwise_dir"],
+             matches=matches,
+             cropped_image_path=cached["pairwise_dir"].parent / "cropped.png",
+         )
+     )
+
+     # Store in global state for match selection
+     LOADED_MODELS["current_match_visualizations"] = match_visualizations
+     LOADED_MODELS["current_clean_comparison_visualizations"] = (
+         clean_comparison_visualizations
+     )
+     LOADED_MODELS["current_enriched_matches"] = matches
+     LOADED_MODELS["current_filter_body_parts"] = filter_body_parts
+     LOADED_MODELS["current_temp_dir"] = None  # No temp dir for cached results
+
+     # Top match info for result text
+     top_match = matches[0]
+     top_leopard_name = top_match["leopard_name"]
+     top_wasserstein = top_match["wasserstein"]
+
+     # Determine confidence level
+     if top_wasserstein >= 0.12:
+         confidence_indicator = "🔵"  # Excellent
+     elif top_wasserstein >= 0.07:
+         confidence_indicator = "🟢"  # Good
+     elif top_wasserstein >= 0.04:
+         confidence_indicator = "🟡"  # Fair
+     else:
+         confidence_indicator = "🔴"  # Uncertain
+
+     result_text = f"## {confidence_indicator} {top_leopard_name.title()}"
+
+     # Build dataset for top-K matches table
+     dataset_samples = []
+     for match in matches:
+         rank = match["rank"]
+         leopard_name = match["leopard_name"]
+         wasserstein = match["wasserstein"]
+
+         # Use location from cache (v2.0) or extract from path
+         location = match.get("location", "unknown")
+         if location == "unknown":
+             catalog_id = match["catalog_id"]
+             catalog_metadata = get_catalog_metadata_for_id(
+                 config.catalog_root, catalog_id
+             )
+             if catalog_metadata:
+                 img_path_parts = Path(catalog_metadata["image_path"]).parts
+                 try:
+                     db_idx = img_path_parts.index("database")
+                     if db_idx + 1 < len(img_path_parts):
+                         location = img_path_parts[db_idx + 1]
+                 except ValueError:
+                     pass
+
+         # Confidence indicator
+         if wasserstein >= 0.12:
+             indicator = "🔵"
+         elif wasserstein >= 0.07:
+             indicator = "🟢"
+         elif wasserstein >= 0.04:
+             indicator = "🟡"
+         else:
+             indicator = "🔴"
+
+         dataset_samples.append(
+             [
+                 rank,
+                 indicator,
+                 leopard_name.title(),
+                 location.replace("_", " ").title(),
+                 f"{wasserstein:.4f}",
+             ]
+         )
+
+     # Load rank 1 details
+     rank1_details = load_match_details_for_rank(rank=1)
+
+     # Return all 23 outputs
+     return (
+         result_text,  # 1. Top match result text
+         cached["segmentation_image"],  # 2. Segmentation overlay
+         cached["cropped_image"],  # 3. Cropped leopard
+         cached["keypoints_image"],  # 4. Extracted keypoints
+         dataset_samples,  # 5. Matches table data
+         *rank1_details,  # 6-23. visualizations, header, indicators, galleries
+     )
+
+
+ def run_identification(
+     image,
+     extractor: str,
+     top_k: int,
+     selected_locations: list[str],
+     selected_body_parts: list[str],
+     example_path: str | None,
+     config: AppConfig,
+ ):
+     """Run snow leopard identification pipeline on uploaded image.
+
+     Args:
+         image: PIL Image from Gradio upload
+         extractor: Feature extractor to use ('sift', 'superpoint', 'disk', 'aliked')
+         top_k: Number of top matches to return
+         selected_locations: List of selected locations (includes "all" for no filtering)
+         selected_body_parts: List of selected body parts (includes "all" for no filtering)
+         example_path: Path to example image if selected from examples (for cache lookup)
+         config: Application configuration
+
+     Returns:
+         Tuple of UI components to update
+     """
+     if image is None:
+         # Return 23 empty outputs (5 pipeline + 18 rank 1 details)
+         return (
+             "Please upload an image first",  # 1. result_text
+             None,  # 2. seg_viz
+             None,  # 3. cropped_image
+             None,  # 4. extracted_kpts_viz
+             [],  # 5. dataset_samples
+             gr.update(value=None),  # 6. matched_kpts_viz
+             gr.update(value=None),  # 7. clean_comparison_viz
+             gr.update(value=""),  # 8. header
+             gr.update(value=""),  # 9. head indicator
+             gr.update(value=""),  # 10. left_flank indicator
+             gr.update(value=""),  # 11. right_flank indicator
+             gr.update(value=""),  # 12. tail indicator
+             gr.update(value=""),  # 13. misc indicator
+             gr.update(visible=False),  # 14. head empty message
+             gr.update(visible=False),  # 15. left_flank empty message
+             gr.update(visible=False),  # 16. right_flank empty message
+             gr.update(visible=False),  # 17. tail empty message
+             gr.update(visible=False),  # 18. misc empty message
+             gr.update(value=[]),  # 19. head gallery
+             gr.update(value=[]),  # 20. left_flank gallery
+             gr.update(value=[]),  # 21. right_flank gallery
+             gr.update(value=[]),  # 22. tail gallery
+             gr.update(value=[]),  # 23. misc gallery
+         )
+
+     # Convert filter selections to None if "all" is selected
+     filter_locations = (
+         None
+         if not selected_locations or "all" in selected_locations
+         else selected_locations
+     )
+     filter_body_parts_parsed = (
+         None
+         if not selected_body_parts or "all" in selected_body_parts
+         else selected_body_parts
+     )
+
+     # Check cache for example images (v2.0 cache supports filtering)
+     if example_path and is_cached(example_path, extractor):
+         logger.info(f"Cache hit for {example_path} with {extractor}")
+         if filter_locations or filter_body_parts_parsed:
+             logger.info(f" Applying filters: locations={filter_locations}, body_parts={filter_body_parts_parsed}")
+         try:
+             return _load_from_cache(
+                 example_path,
+                 extractor,
+                 config,
+                 filter_locations=filter_locations,
+                 filter_body_parts=filter_body_parts_parsed,
+                 top_k=int(top_k),
+             )
+         except Exception as e:
+             logger.warning(f"Cache load failed, running pipeline: {e}")
+             # Fall through to run full pipeline
+
+     # Use the already-parsed filter values for the pipeline
+     filter_body_parts = filter_body_parts_parsed
+
+     # Log applied filters
+     if filter_locations or filter_body_parts:
+         filter_desc = []
+         if filter_locations:
+             filter_desc.append(f"locations: {', '.join(filter_locations)}")
+         if filter_body_parts:
+             filter_desc.append(f"body parts: {', '.join(filter_body_parts)}")
+         logger.info(f"Applied filters - {' | '.join(filter_desc)}")
+     else:
+         logger.info("No filters applied - matching against entire catalog")
+
+     try:
+         # Create temporary directory for this query
+         temp_dir = Path(tempfile.mkdtemp(prefix="snowleopard_id_"))
+         temp_image_path = temp_dir / "query.jpg"
+
+         # Save uploaded image
+         logger.info(f"Image type: {type(image)}")
+         logger.info(f"Image mode: {image.mode if hasattr(image, 'mode') else 'N/A'}")
+         logger.info(f"Image size: {image.size if hasattr(image, 'size') else 'N/A'}")
+         image.save(temp_image_path, quality=95)
+
+         # Verify saved image
+         saved_size = temp_image_path.stat().st_size
+         logger.info(f"Saved image size: {saved_size / 1024 / 1024:.2f} MB")
+
+         logger.info(f"Processing query image: {temp_image_path}")
+
+         device = LOADED_MODELS.get("device", "cpu")
+
+         # Step 1: Run GDINO+SAM segmentation using pre-loaded models
+         logger.info("Running GDINO+SAM segmentation...")
+         gdino_processor = LOADED_MODELS.get("gdino_processor")
+         gdino_model = LOADED_MODELS.get("gdino_model")
+         sam_predictor = LOADED_MODELS.get("sam_predictor")
+         text_prompt = LOADED_MODELS.get("text_prompt", "a snow leopard.")
+
+         seg_stage = run_segmentation_stage(
+             image_path=temp_image_path,
+             strategy="gdino_sam",
+             confidence_threshold=0.2,
+             device=device,
+             gdino_processor=gdino_processor,
+             gdino_model=gdino_model,
+             sam_predictor=sam_predictor,
+             text_prompt=text_prompt,
+             box_threshold=0.30,
+             text_threshold=0.20,
+         )
+
+         predictions = seg_stage["data"]["predictions"]
+         logger.info(f"Number of predictions: {len(predictions)}")
+
+         if not predictions:
+             logger.warning("No predictions found from segmentation")
+             logger.warning(f"Full segmentation stage: {seg_stage}")
+             # Return 23 empty outputs (5 pipeline + 18 rank 1 details)
+             return (
+                 "No snow leopards detected in image",  # 1. result_text
+                 None,  # 2. seg_viz
+                 None,  # 3. cropped_image
+                 None,  # 4. extracted_kpts_viz
+                 [],  # 5. dataset_samples
+                 gr.update(value=None),  # 6. matched_kpts_viz
+                 gr.update(value=None),  # 7. clean_comparison_viz
+                 gr.update(value=""),  # 8. header
+                 gr.update(value=""),  # 9. head indicator
+                 gr.update(value=""),  # 10. left_flank indicator
+                 gr.update(value=""),  # 11. right_flank indicator
+                 gr.update(value=""),  # 12. tail indicator
+                 gr.update(value=""),  # 13. misc indicator
+                 gr.update(visible=False),  # 14. head empty message
+                 gr.update(visible=False),  # 15. left_flank empty message
+                 gr.update(visible=False),  # 16. right_flank empty message
+                 gr.update(visible=False),  # 17. tail empty message
+                 gr.update(visible=False),  # 18. misc empty message
+                 gr.update(value=[]),  # 19. head gallery
+                 gr.update(value=[]),  # 20. left_flank gallery
+                 gr.update(value=[]),  # 21. right_flank gallery
+                 gr.update(value=[]),  # 22. tail gallery
+                 gr.update(value=[]),  # 23. misc gallery
+             )
+
+         # Step 2: Select best mask
+         logger.info("Selecting best mask...")
+         selected_idx, selected_pred = select_best_mask(
+             predictions,
+             strategy="confidence_area",
+         )
+
+         # Step 3: Preprocess (crop and mask)
+         logger.info("Preprocessing query image...")
+         prep_stage = run_preprocess_stage(
+             image_path=temp_image_path,
+             mask=selected_pred["mask"],
+             padding=5,
+         )
+
+         cropped_image_pil = prep_stage["data"]["cropped_image"]
+
+         # Save cropped image for visualization later
+         cropped_path = temp_dir / "cropped.jpg"
+         cropped_image_pil.save(cropped_path)
+
+         # Step 4: Extract features
+         logger.info(f"Extracting features using {extractor.upper()}...")
+         feat_stage = run_feature_extraction_stage(
+             image=cropped_image_pil,
+             extractor=extractor,
+             max_keypoints=2048,
+             device=device,
+         )
+
+         query_features = feat_stage["data"]["features"]
+
+         # Step 5: Match against catalog
+         logger.info("Matching against catalog...")
+         pairwise_dir = temp_dir / "pairwise"
+         pairwise_dir.mkdir(exist_ok=True)
+
+         match_stage = run_matching_stage(
+             query_features=query_features,
+             catalog_path=config.catalog_root,
+             top_k=top_k,
+             extractor=extractor,
+             device=device,
+             query_image_path=str(cropped_path),
+             pairwise_output_dir=pairwise_dir,
+             filter_locations=filter_locations,
+             filter_body_parts=filter_body_parts,
+         )
+
+         matches = match_stage["data"]["matches"]
+
+         if not matches:
+             # Return 23 empty outputs (5 pipeline + 18 rank 1 details)
+             return (
+                 "No matches found in catalog",  # 1. result_text
+                 None,  # 2. seg_viz
+                 cropped_image_pil,  # 3. cropped_image
+                 None,  # 4. extracted_kpts_viz
+                 [],  # 5. dataset_samples
+                 gr.update(value=None),  # 6. matched_kpts_viz
+                 gr.update(value=None),  # 7. clean_comparison_viz
+                 gr.update(value=""),  # 8. header
+                 gr.update(value=""),  # 9. head indicator
+                 gr.update(value=""),  # 10. left_flank indicator
+                 gr.update(value=""),  # 11. right_flank indicator
+                 gr.update(value=""),  # 12. tail indicator
+                 gr.update(value=""),  # 13. misc indicator
+                 gr.update(visible=False),  # 14. head empty message
+                 gr.update(visible=False),  # 15. left_flank empty message
+                 gr.update(visible=False),  # 16. right_flank empty message
+                 gr.update(visible=False),  # 17. tail empty message
+                 gr.update(visible=False),  # 18. misc empty message
+                 gr.update(value=[]),  # 19. head gallery
+                 gr.update(value=[]),  # 20. left_flank gallery
+                 gr.update(value=[]),  # 21. right_flank gallery
+                 gr.update(value=[]),  # 22. tail gallery
+                 gr.update(value=[]),  # 23. misc gallery
+             )
+
+         # Top match
+         top_match = matches[0]
+         top_leopard_name = top_match["leopard_name"]
+         top_wasserstein = top_match["wasserstein"]
+
+         # Determine confidence level (higher Wasserstein = better match)
+         if top_wasserstein >= 0.12:
+             confidence_indicator = "🔵"  # Excellent
+         elif top_wasserstein >= 0.07:
+             confidence_indicator = "🟢"  # Good
+         elif top_wasserstein >= 0.04:
+             confidence_indicator = "🟡"  # Fair
+         else:
+             confidence_indicator = "🔴"  # Uncertain
+
+         result_text = f"## {confidence_indicator} {top_leopard_name.title()}"
+
+         # Create segmentation visualization
+         seg_viz = create_segmentation_viz(
+             image_path=temp_image_path, mask=selected_pred["mask"]
+         )
+
+         # Generate extracted keypoints visualization
+         extracted_kpts_viz = None
+         try:
+             # Extract keypoints from query features for visualization
+             query_kpts = query_features["keypoints"].cpu().numpy()
+             extracted_kpts_viz = draw_keypoints_overlay(
+                 image_path=cropped_path,
+                 keypoints=query_kpts,
+                 max_keypoints=500,
+                 color="blue",
+                 ps=10,
+             )
+         except Exception as e:
+             logger.error(f"Error creating extracted keypoints visualization: {e}")
+
+         # Build dataset for top-K matches table
+         dataset_samples = []
+         match_visualizations = {}
+         clean_comparison_visualizations = {}
+
+         for match in matches:
+             rank = match["rank"]
+             leopard_name = match["leopard_name"]
+             wasserstein = match["wasserstein"]
+             catalog_img_path = Path(match["filepath"])
+
+             # Get location from catalog metadata
+             catalog_id = match["catalog_id"]
+             catalog_metadata = get_catalog_metadata_for_id(
+                 config.catalog_root, catalog_id
+             )
+             location = "unknown"
+             if catalog_metadata:
+                 # Extract location from path: database/{location}/{individual}/...
+                 img_path_parts = Path(catalog_metadata["image_path"]).parts
+                 if len(img_path_parts) >= 3:
+                     # Find 'database' in path and get next part
+                     try:
+                         db_idx = img_path_parts.index("database")
+                         if db_idx + 1 < len(img_path_parts):
+                             location = img_path_parts[db_idx + 1]
+                     except ValueError:
+                         pass
+
+             # Confidence indicator (higher Wasserstein = better match)
+             if wasserstein >= 0.12:
+                 indicator = "🔵"  # Excellent
+             elif wasserstein >= 0.07:
+                 indicator = "🟢"  # Good
+             elif wasserstein >= 0.04:
+                 indicator = "🟡"  # Fair
+             else:
+                 indicator = "🔴"  # Uncertain
+
+             # Create visualizations for this match
+             npz_path = pairwise_dir / f"rank_{rank:02d}_{match['catalog_id']}.npz"
+             if npz_path.exists():
+                 try:
+                     pairwise_data = np.load(npz_path)
+
+                     # Create matched keypoints visualization
+                     match_viz = draw_matched_keypoints(
+                         query_image_path=cropped_path,
+                         catalog_image_path=catalog_img_path,
+                         query_keypoints=pairwise_data["query_keypoints"],
+                         catalog_keypoints=pairwise_data["catalog_keypoints"],
+                         match_scores=pairwise_data["match_scores"],
+                         max_matches=100,
+                     )
+                     match_visualizations[rank] = match_viz
+
+                     # Create clean comparison visualization
+                     clean_viz = draw_side_by_side_comparison(
+                         query_image_path=cropped_path,
+                         catalog_image_path=catalog_img_path,
+                     )
+                     clean_comparison_visualizations[rank] = clean_viz
+                 except Exception as e:
+                     logger.error(f"Error creating visualizations for rank {rank}: {e}")
+
+             # Format for table (as list, not dict)
+             dataset_samples.append(
+                 [
+                     rank,
+                     indicator,
+                     leopard_name.title(),
+                     location.replace("_", " ").title(),
+                     f"{wasserstein:.4f}",
+                 ]
+             )
+
+         # Store match visualizations, enriched matches, filters, and temp_dir in global state
+         LOADED_MODELS["current_match_visualizations"] = match_visualizations
+         LOADED_MODELS["current_clean_comparison_visualizations"] = (
+             clean_comparison_visualizations
+         )
+         LOADED_MODELS["current_enriched_matches"] = matches
+         LOADED_MODELS["current_filter_body_parts"] = filter_body_parts
+         LOADED_MODELS["current_temp_dir"] = temp_dir
+
+         # Automatically load rank 1 details (visualizations + galleries)
+         rank1_details = load_match_details_for_rank(rank=1)
+
+         # Return 23 outputs total:
+         # - 5 pipeline outputs (result_text, seg_viz, cropped_image, extracted_kpts_viz, dataset_samples)
+         # - 18 rank 1 details (from load_match_details_for_rank)
+         return (
+             result_text,  # 1. Top match result text
+             seg_viz,  # 2. Segmentation overlay
+             cropped_image_pil,  # 3. Cropped leopard
+             extracted_kpts_viz,  # 4. Extracted keypoints
+             dataset_samples,  # 5. Matches table data
+             # Unpack all 18 rank 1 details:
+             *rank1_details,  # 6-23. visualizations, header, indicators, galleries
+         )
+
+     except Exception as e:
+         logger.error(f"Error processing image: {e}", exc_info=True)
+         # Return 23 empty outputs (5 pipeline + 18 rank 1 details)
+         return (
+             f"Error processing image: {str(e)}",  # 1. result_text
+             None,  # 2. seg_viz
+             None,  # 3. cropped_image
+             None,  # 4. extracted_kpts_viz
+             [],  # 5. dataset_samples
+             gr.update(value=None),  # 6. matched_kpts_viz
+             gr.update(value=None),  # 7. clean_comparison_viz
+             gr.update(value=""),  # 8. header
+             gr.update(value=""),  # 9. head indicator
+             gr.update(value=""),  # 10. left_flank indicator
+             gr.update(value=""),  # 11. right_flank indicator
+             gr.update(value=""),  # 12. tail indicator
+             gr.update(value=""),  # 13. misc indicator
+             gr.update(visible=False),  # 14. head empty message
+             gr.update(visible=False),  # 15. left_flank empty message
+             gr.update(visible=False),  # 16. right_flank empty message
+             gr.update(visible=False),  # 17. tail empty message
+             gr.update(visible=False),  # 18. misc empty message
+             gr.update(value=[]),  # 19. head gallery
+             gr.update(value=[]),  # 20. left_flank gallery
+             gr.update(value=[]),  # 21. right_flank gallery
+             gr.update(value=[]),  # 22. tail gallery
+             gr.update(value=[]),  # 23. misc gallery
+         )
+
+
+ def create_segmentation_viz(image_path, mask):
+     """Create visualization of segmentation mask overlaid on image."""
+     # Load original image (cv2.imread returns None if the file cannot be read)
+     img = cv2.imread(str(image_path))
+     if img is None:
+         raise ValueError(f"Could not read image: {image_path}")
+     img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+     # Resize mask to match image dimensions if needed
+     if mask.shape[:2] != img_rgb.shape[:2]:
+         mask_resized = cv2.resize(
+             mask.astype(np.uint8),
+             (img_rgb.shape[1], img_rgb.shape[0]),
+             interpolation=cv2.INTER_NEAREST,
+         )
+     else:
+         mask_resized = mask
+
+     # Create colored overlay
+     overlay = img_rgb.copy()
+     overlay[mask_resized > 0] = [255, 0, 0]  # Red for masked region
+
+     # Blend
+     alpha = 0.4
+     blended = cv2.addWeighted(img_rgb, 1 - alpha, overlay, alpha, 0)
+
+     return Image.fromarray(blended)
+
+
+ def load_match_details_for_rank(rank: int) -> tuple:
+     """Load all match details (visualizations + galleries) for a specific rank.
+
+     This is a reusable helper function that encapsulates the logic for loading
+     match visualizations, galleries, and metadata for a given rank. Used by both
+     the automatic rank 1 display after pipeline completion and the interactive
+     row selection handler.
+
+     Args:
+         rank: The rank to load (1-indexed)
+
+     Returns:
+         Tuple of 18 Gradio component updates:
+         (matched_kpts_viz, clean_comparison_viz, header,
+         head_indicator, left_flank_indicator, right_flank_indicator, tail_indicator, misc_indicator,
+         head_empty_message, left_flank_empty_message, right_flank_empty_message,
+         tail_empty_message, misc_empty_message,
+         gallery_head, gallery_left_flank, gallery_right_flank, gallery_tail, gallery_misc)
+     """
+     # Get stored data from global state
+     match_visualizations = LOADED_MODELS.get("current_match_visualizations", {})
+     clean_comparison_visualizations = LOADED_MODELS.get(
+         "current_clean_comparison_visualizations", {}
+     )
+     enriched_matches = LOADED_MODELS.get("current_enriched_matches", [])
+     filter_body_parts = LOADED_MODELS.get("current_filter_body_parts")
+     catalog_root = LOADED_MODELS.get("catalog_root")
+
+     # Find the match for the requested rank
+     selected_match = None
+     for match in enriched_matches:
+         if match["rank"] == rank:
+             selected_match = match
+             break
+
+     if not selected_match or rank not in match_visualizations:
+         # Return empty updates for all 18 outputs
+         return (
+             gr.update(value=None),  # 1. matched_kpts_viz
+             gr.update(value=None),  # 2. clean_comparison_viz
+             gr.update(value=""),  # 3. header
+             gr.update(value=""),  # 4. head indicator
+             gr.update(value=""),  # 5. left_flank indicator
+             gr.update(value=""),  # 6. right_flank indicator
+             gr.update(value=""),  # 7. tail indicator
+             gr.update(value=""),  # 8. misc indicator
+             gr.update(visible=False),  # 9. head empty message
+             gr.update(visible=False),  # 10. left_flank empty message
+             gr.update(visible=False),  # 11. right_flank empty message
+             gr.update(visible=False),  # 12. tail empty message
+             gr.update(visible=False),  # 13. misc empty message
+             gr.update(value=[]),  # 14. head gallery
+             gr.update(value=[]),  # 15. left_flank gallery
+             gr.update(value=[]),  # 16. right_flank gallery
+             gr.update(value=[]),  # 17. tail gallery
+             gr.update(value=[]),  # 18. misc gallery
+         )
+
+     # Get both visualizations
+     match_viz = match_visualizations[rank]
+     clean_viz = clean_comparison_visualizations.get(rank)
+
+     # Create dynamic header with leopard name
+     leopard_name = selected_match["leopard_name"]
+     header_text = f"## Reference Images for {leopard_name.title()}"
+
+     # Load galleries organized by body part
+     galleries = {}
+     if catalog_root:
+         try:
+             # Extract location from match filepath
+             location = None
+             filepath = Path(selected_match["filepath"])
+             parts = filepath.parts
+             if "database" in parts:
+                 db_idx = parts.index("database")
+                 if db_idx + 1 < len(parts):
+                     location = parts[db_idx + 1]
+
+             galleries = load_matched_individual_gallery_by_body_part(
+                 catalog_root=catalog_root,
+                 leopard_name=leopard_name,
+                 location=location,
+             )
+         except Exception as e:
+             logger.error(f"Error loading gallery for {leopard_name}: {e}")
+             # Initialize empty galleries on error
+             galleries = {
+                 "head": [],
+                 "left_flank": [],
+                 "right_flank": [],
+                 "tail": [],
+                 "misc": [],
+             }
+
+     # Create text indicators for body parts that were part of the filter
+     def get_indicator(body_part: str) -> str:
+         """Return "* (filtered)" if body part was in filter, empty string otherwise."""
+         if filter_body_parts and body_part in filter_body_parts:
+             return "* (filtered)"
+         return ""
+
+     # Helper to determine if empty message should be visible
+     def is_empty(body_part: str) -> bool:
+         """Return True if no images for this body part."""
+         return len(galleries.get(body_part, [])) == 0
+
+     return (
+         gr.update(value=match_viz),  # 1. matched_kpts_viz
+         gr.update(value=clean_viz),  # 2. clean_comparison_viz
+         gr.update(value=header_text),  # 3. header
+         gr.update(value=get_indicator("head")),  # 4. head indicator
+         gr.update(value=get_indicator("left_flank")),  # 5. left_flank indicator
+         gr.update(value=get_indicator("right_flank")),  # 6. right_flank indicator
+         gr.update(value=get_indicator("tail")),  # 7. tail indicator
+         gr.update(value=get_indicator("misc")),  # 8. misc indicator
+         gr.update(visible=is_empty("head")),  # 9. head empty message
+         gr.update(visible=is_empty("left_flank")),  # 10. left_flank empty message
+         gr.update(visible=is_empty("right_flank")),  # 11. right_flank empty message
+         gr.update(visible=is_empty("tail")),  # 12. tail empty message
+         gr.update(visible=is_empty("misc")),  # 13. misc empty message
+         gr.update(
+             value=galleries.get("head", []), visible=not is_empty("head")
+         ),  # 14. head gallery
+         gr.update(
+             value=galleries.get("left_flank", []), visible=not is_empty("left_flank")
+         ),  # 15. left_flank gallery
+         gr.update(
+             value=galleries.get("right_flank", []), visible=not is_empty("right_flank")
+         ),  # 16. right_flank gallery
+         gr.update(
+             value=galleries.get("tail", []), visible=not is_empty("tail")
+         ),  # 17. tail gallery
+         gr.update(
+             value=galleries.get("misc", []), visible=not is_empty("misc")
+         ),  # 18. misc gallery
+     )
+
+
+ def on_match_selected(evt: gr.SelectData):
+     """Handle selection of a match from the dataset table.
+
+     Returns both visualizations, header, indicators, empty messages,
+     and galleries organized by body part.
+     """
+     # evt.index is [row, col] for Dataframe, we want row
+     if isinstance(evt.index, (list, tuple)):
+         selected_row = evt.index[0]
+     else:
+         selected_row = evt.index
+
+     selected_rank = selected_row + 1  # Ranks are 1-indexed
+
+     # Delegate to the reusable helper function
+     return load_match_details_for_rank(selected_rank)
+
+
+ def load_matched_individual_gallery_by_body_part(
+     catalog_root: Path,
+     leopard_name: str,
+     location: str | None = None,
+ ) -> dict[str, list[tuple]]:
+     """Load all images for a matched individual organized by body part.
+
+     Args:
+         catalog_root: Path to catalog root directory
+         leopard_name: Name of the matched individual (e.g., "karindas")
+         location: Geographic location (e.g., "skycrest_valley")
+
+     Returns:
+         Dict mapping body part to list of (PIL.Image, caption) tuples:
+         {
+             "head": [(img1, caption1), (img2, caption2), ...],
+             "left_flank": [...],
+             "right_flank": [...],
+             "tail": [...],
+             "misc": [...]
+         }
+     """
+     # Initialize dict with all body parts
+     galleries = {
+         "head": [],
+         "left_flank": [],
+         "right_flank": [],
+         "tail": [],
+         "misc": [],
+     }
+
+     # Find metadata path: database/{location}/{individual}/metadata.yaml
+     if location:
+         metadata_path = (
+             catalog_root / "database" / location / leopard_name / "metadata.yaml"
+         )
+     else:
+         # Try to find the individual in any location
+         metadata_path = None
+         database_dir = catalog_root / "database"
+         if database_dir.exists():
+             for loc_dir in database_dir.iterdir():
+                 if loc_dir.is_dir():
+                     potential_path = loc_dir / leopard_name / "metadata.yaml"
+                     if potential_path.exists():
+                         metadata_path = potential_path
+                         break
+
+     if not metadata_path or not metadata_path.exists():
+         logger.warning(f"Metadata not found for {leopard_name}")
+         return galleries
+
+     try:
+         metadata = load_leopard_metadata(metadata_path)
+
+         # Load all images organized by body part
+         for img_entry in metadata["reference_images"]:
+             body_part = img_entry.get("body_part", "misc")
+
+             # Normalize body_part to match our keys
+             if body_part not in galleries:
+                 body_part = "misc"  # Default to misc if unknown
+
+             # Load image
+             img_path = catalog_root / "database" / img_entry["path"]
+
+             try:
+                 img = Image.open(img_path)
+                 # Simple caption: just body part name
+                 caption = body_part
+                 galleries[body_part].append((img, caption))
+             except Exception as e:
+                 logger.error(f"Error loading image {img_path}: {e}")
+
+     except Exception as e:
+         logger.error(f"Error loading metadata for {leopard_name}: {e}")
+
+     return galleries
+
+
+ def cleanup_temp_files():
+     """Clean up temporary files from previous run."""
+     temp_dir = LOADED_MODELS.get("current_temp_dir")
+     if temp_dir and temp_dir.exists():
+         try:
+             shutil.rmtree(temp_dir)
+             logger.info(f"Cleaned up temporary directory: {temp_dir}")
+         except Exception as e:
+             logger.warning(f"Error cleaning up temp directory: {e}")
+
+
+ def create_leopard_tab(leopard_metadata, config: AppConfig):
+     """Create a tab for displaying a single leopard's images.
+
+     Args:
+         leopard_metadata: Metadata dictionary for the leopard individual
+         config: Application configuration
+     """
+     # Support both 'leopard_name' and 'individual_name' keys
+     leopard_name = leopard_metadata.get("leopard_name") or leopard_metadata.get(
+         "individual_name"
+     )
+     location = leopard_metadata.get("location", "unknown")
+     total_images = leopard_metadata["statistics"]["total_reference_images"]
+
+     # Get body parts from statistics
+     body_parts = leopard_metadata["statistics"].get(
+         "body_parts_represented", leopard_metadata["statistics"].get("body_parts", [])
+     )
+     body_parts_str = ", ".join(body_parts) if body_parts else "N/A"
+
+     with gr.Tab(f"{leopard_name}"):
+         # Header with statistics
+         gr.Markdown(
+             f"### {leopard_name.title()}\n"
+             f"**Location:** {location.replace('_', ' ').title()} | "
+             f"**{total_images} images** | "
+             f"**Body parts:** {body_parts_str}"
+         )
+
+         # Load all images with body_part captions
+         gallery_data = []
+         for img_entry in leopard_metadata["reference_images"]:
+             img_path = config.catalog_root / "database" / img_entry["path"]
+             body_part = img_entry.get("body_part", "unknown")
+             try:
+                 img = Image.open(img_path)
+                 # Caption format: just body_part (location is already in tab)
+                 caption = body_part
+                 gallery_data.append((img, caption))
+             except Exception as e:
+                 logger.error(f"Error loading image {img_path}: {e}")
+
+         # Display gallery
+         gr.Gallery(
+             value=gallery_data,
+             label=f"Reference Images for {leopard_name.title()}",
+             columns=6,
+             height=700,
+             object_fit="scale-down",
+             allow_preview=True,
+         )
1129
+
1130
+
1131
+ def create_app(config: AppConfig):
1132
+ """Create and configure the Gradio application.
1133
+
1134
+ Args:
1135
+ config: Application configuration
1136
+ """
1137
+ # Extract data archives on first run (for HF Spaces deployment)
1138
+ ensure_data_extracted()
1139
+
1140
+ # Initialize models at startup
1141
+ initialize_models(config)
1142
+
1143
+ # Load catalog data
1144
+ catalog_index, individuals_data = load_catalog_data(config)
1145
+
1146
+ # Build example images list from examples directory
1147
+ example_images = (
1148
+ list(config.examples_dir.glob("*.jpg"))
1149
+ + list(config.examples_dir.glob("*.JPG"))
1150
+ + list(config.examples_dir.glob("*.png"))
1151
+ )
1152
+ # Sort with Ayima images last
1153
+ example_images.sort(key=lambda x: (1 if "Ayima" in x.name else 0, x.name))
1154
+
1155
+ # Create interface
1156
+ with gr.Blocks(title="Snow Leopard Identification") as app:
1157
+ # Hidden state to track which example image was selected (for cache lookup)
1158
+ selected_example_state = gr.State(value=None)
1159
+
1160
+ gr.HTML("""
1161
+ <div style="text-align: center; margin-bottom: 20px;">
1162
+ <h1 style="margin-bottom: 10px;">Snow Leopard Identification</h1>
1163
+ <p style="font-size: 16px; color: #666;">
1164
+ Computer vision system for identifying individual snow leopards.
1165
+ </p>
1166
+ </div>
1167
+ """)
1168
+
1169
+ # Main tabs
1170
+ with gr.Tabs():
1171
+ # Tab 1: Identify Snow Leopard
1172
+ with gr.Tab("Identify Snow Leopard"):
1173
+ gr.Markdown("""
1174
+ Upload a snow leopard image or select an example to identify which individual it is.
1175
+ The system will detect the leopard, extract distinctive features, and match against the catalog.
1176
+ """)
1177
+
1178
+ with gr.Row():
1179
+ # Left column: Input
1180
+ with gr.Column(scale=1):
1181
+ image_input = gr.Image(
1182
+ type="pil",
1183
+ label="Upload Snow Leopard Image",
1184
+ sources=["upload", "clipboard"],
1185
+ )
1186
+
1187
+ examples_component = gr.Examples(
1188
+ examples=[[str(img)] for img in example_images],
1189
+ inputs=image_input,
1190
+ label="Example Images",
1191
+ )
1192
+
1193
+ # Track example selection for cache lookup
1194
+ def on_example_select(evt: gr.SelectData):
1195
+ """Update state when an example is selected."""
1196
+ if evt.index is not None:
1197
+ return str(example_images[evt.index])
1198
+ return None
1199
+
1200
+ # When image changes, check if it matches an example
1201
+ def check_if_example(img):
1202
+ """Unused placeholder: returns a no-op update, so the example state is left unchanged."""
1203
+ # Not wired to any event. The example state is set by the Examples
1204
+ # select event and cleared by the image upload event below.
1205
+ return gr.update() # No change to state on image change
1206
+
1207
+ examples_component.dataset.select(
1208
+ fn=on_example_select,
1209
+ outputs=[selected_example_state],
1210
+ )
1211
+
1212
+ # Clear example state when user uploads a new image
1213
+ image_input.upload(
1214
+ fn=lambda: None,
1215
+ outputs=[selected_example_state],
1216
+ )
1217
+
1218
+ # Location filter dropdown
1219
+ available_locations = get_available_locations(
1220
+ config.catalog_root
1221
+ )
1222
+ location_filter = gr.Dropdown(
1223
+ choices=available_locations,
1224
+ value=["all"],
1225
+ multiselect=True,
1226
+ label="Filter by Location",
1227
+ info="Select locations to search (default: all locations)",
1228
+ )
1229
+
1230
+ # Body part filter dropdown
1231
+ available_body_parts = get_available_body_parts(
1232
+ config.catalog_root
1233
+ )
1234
+ body_part_filter = gr.Dropdown(
1235
+ choices=available_body_parts,
1236
+ value=["all"],
1237
+ multiselect=True,
1238
+ label="Filter by Body Part",
1239
+ info="Select body parts to match (default: all body parts)",
1240
+ )
1241
+
1242
+ # Advanced Configuration Accordion
1243
+ with gr.Accordion("Advanced Configuration", open=False):
1244
+ # Feature extractor dropdown
1245
+ available_extractors = get_available_extractors(
1246
+ config.catalog_root
1247
+ )
1248
+ extractor_dropdown = gr.Dropdown(
1249
+ choices=available_extractors,
1250
+ value="sift"
1251
+ if "sift" in available_extractors
1252
+ else (
1253
+ available_extractors[0]
1254
+ if available_extractors
1255
+ else "sift"
1256
+ ),
1257
+ label="Feature Extractor",
1258
+ info=f"Available: {', '.join(available_extractors)}",
1259
+ scale=1,
1260
+ )
1261
+
1262
+ # Top-K parameter
1263
+ top_k_input = gr.Number(
1264
+ value=config.top_k,
1265
+ label="Top-K Matches",
1266
+ info="Number of top matches to return",
1267
+ minimum=1,
1268
+ maximum=20,
1269
+ step=1,
1270
+ precision=0,
1271
+ scale=1,
1272
+ )
1273
+
1274
+ submit_btn = gr.Button(
1275
+ value="Identify Snow Leopard",
1276
+ variant="primary",
1277
+ size="lg",
1278
+ )
1279
+
1280
+ # Right column: Results
1281
+ with gr.Column(scale=4):
1282
+ # Top-1 prediction
1283
+ result_text = gr.Markdown("")
1284
+
1285
+ # Tabs for different result views
1286
+ with gr.Tabs():
1287
+ with gr.Tab("Model Internals"):
1288
+ gr.Markdown("""
1289
+ View the internal processing steps: segmentation mask, cropped leopard, and extracted keypoints.
1290
+ """)
1291
+ with gr.Row():
1292
+ seg_viz = gr.Image(
1293
+ label="Segmentation Overlay",
1294
+ type="pil",
1295
+ )
1296
+ cropped_image = gr.Image(
1297
+ label="Extracted Snow Leopard",
1298
+ type="pil",
1299
+ )
1300
+ extracted_kpts_viz = gr.Image(
1301
+ label="Extracted Keypoints",
1302
+ type="pil",
1303
+ )
1304
+
1305
+ with gr.Tab("Top Matches"):
1306
+ gr.Markdown("""
1307
+ Click a row to view detailed feature matching visualization and all reference images for that leopard.
1308
+
1309
+ **Higher Wasserstein distance = better match** (typical range: 0.04-0.27)
1310
+
1311
+ **Confidence Levels:** 🔵 Excellent (>=0.12) | 🟢 Good (>=0.07) | 🟡 Fair (>=0.04) | 🔴 Uncertain (<0.04)
1312
+ """)
1313
+
1314
+ matches_dataset = gr.Dataframe(
1315
+ headers=[
1316
+ "Rank",
1317
+ "Confidence",
1318
+ "Leopard Name",
1319
+ "Location",
1320
+ "Wasserstein",
1321
+ ],
1322
+ label="Top Matches",
1323
+ wrap=True,
1324
+ col_count=(5, "fixed"),
1325
+ )
1326
+
1327
+ # Visualization container (always visible, images populated on pipeline completion)
1328
+ with gr.Column() as viz_tabs:
1329
+ # Tabbed visualization views
1330
+ with gr.Tabs():
1331
+ with gr.Tab("Matched Keypoints"):
1332
+ gr.Markdown(
1333
+ "Feature matching with keypoints and confidence-coded connecting lines. "
1334
+ "**Green** = high confidence, **Yellow** = medium, **Red** = low."
1335
+ )
1336
+ matched_kpts_viz = gr.Image(
1337
+ type="pil",
1338
+ show_label=False,
1339
+ )
1340
+
1341
+ with gr.Tab("Clean Comparison"):
1342
+ gr.Markdown(
1343
+ "Side-by-side comparison without feature annotations. "
1344
+ "Useful for assessing overall visual similarity and spotting patterns."
1345
+ )
1346
+ clean_comparison_viz = gr.Image(
1347
+ type="pil",
1348
+ show_label=False,
1349
+ )
1350
+
1351
+ # Dynamic header showing matched leopard name
1352
+ selected_match_header = gr.Markdown("", visible=True)
1353
+
1354
+ # Create tabs for each body part
1355
+ with gr.Tabs():
1356
+ with gr.Tab("Head"):
1357
+ head_indicator = gr.Markdown("")
1358
+ head_empty_message = gr.Markdown(
1359
+ value='<div style="text-align: center; padding: 60px 20px; color: #888;">'
1360
+ '<p style="font-size: 16px;">No reference images available for this body part</p>'
1361
+ "</div>",
1362
+ visible=False,
1363
+ )
1364
+ gallery_head = gr.Gallery(
1365
+ columns=6,
1366
+ height=400,
1367
+ object_fit="scale-down",
1368
+ allow_preview=True,
1369
+ )
1370
+
1371
+ with gr.Tab("Left Flank"):
1372
+ left_flank_indicator = gr.Markdown("")
1373
+ left_flank_empty_message = gr.Markdown(
1374
+ value='<div style="text-align: center; padding: 60px 20px; color: #888;">'
1375
+ '<p style="font-size: 16px;">No reference images available for this body part</p>'
1376
+ "</div>",
1377
+ visible=False,
1378
+ )
1379
+ gallery_left_flank = gr.Gallery(
1380
+ columns=6,
1381
+ height=400,
1382
+ object_fit="scale-down",
1383
+ allow_preview=True,
1384
+ )
1385
+
1386
+ with gr.Tab("Right Flank"):
1387
+ right_flank_indicator = gr.Markdown("")
1388
+ right_flank_empty_message = gr.Markdown(
1389
+ value='<div style="text-align: center; padding: 60px 20px; color: #888;">'
1390
+ '<p style="font-size: 16px;">No reference images available for this body part</p>'
1391
+ "</div>",
1392
+ visible=False,
1393
+ )
1394
+ gallery_right_flank = gr.Gallery(
1395
+ columns=6,
1396
+ height=400,
1397
+ object_fit="scale-down",
1398
+ allow_preview=True,
1399
+ )
1400
+
1401
+ with gr.Tab("Tail"):
1402
+ tail_indicator = gr.Markdown("")
1403
+ tail_empty_message = gr.Markdown(
1404
+ value='<div style="text-align: center; padding: 60px 20px; color: #888;">'
1405
+ '<p style="font-size: 16px;">No reference images available for this body part</p>'
1406
+ "</div>",
1407
+ visible=False,
1408
+ )
1409
+ gallery_tail = gr.Gallery(
1410
+ columns=6,
1411
+ height=400,
1412
+ object_fit="scale-down",
1413
+ allow_preview=True,
1414
+ )
1415
+
1416
+ with gr.Tab("Other"):
1417
+ misc_indicator = gr.Markdown("")
1418
+ misc_empty_message = gr.Markdown(
1419
+ value='<div style="text-align: center; padding: 60px 20px; color: #888;">'
1420
+ '<p style="font-size: 16px;">No reference images available for this body part</p>'
1421
+ "</div>",
1422
+ visible=False,
1423
+ )
1424
+ gallery_misc = gr.Gallery(
1425
+ columns=6,
1426
+ height=400,
1427
+ object_fit="scale-down",
1428
+ allow_preview=True,
1429
+ )
1430
+
1431
+ # Connect submit button
1432
+ submit_btn.click(
1433
+ fn=lambda img, ext, top_k, locs, parts, ex_path: run_identification(
1434
+ image=img,
1435
+ extractor=ext,
1436
+ top_k=int(top_k),
1437
+ selected_locations=locs,
1438
+ selected_body_parts=parts,
1439
+ example_path=ex_path,
1440
+ config=config,
1441
+ ),
1442
+ inputs=[
1443
+ image_input,
1444
+ extractor_dropdown,
1445
+ top_k_input,
1446
+ location_filter,
1447
+ body_part_filter,
1448
+ selected_example_state,
1449
+ ],
1450
+ outputs=[
1451
+ # Pipeline outputs (5 total)
1452
+ result_text,
1453
+ seg_viz,
1454
+ cropped_image,
1455
+ extracted_kpts_viz,
1456
+ matches_dataset,
1457
+ # Rank 1 auto-display outputs (18 total)
1458
+ matched_kpts_viz,
1459
+ clean_comparison_viz,
1460
+ selected_match_header,
1461
+ head_indicator,
1462
+ left_flank_indicator,
1463
+ right_flank_indicator,
1464
+ tail_indicator,
1465
+ misc_indicator,
1466
+ head_empty_message,
1467
+ left_flank_empty_message,
1468
+ right_flank_empty_message,
1469
+ tail_empty_message,
1470
+ misc_empty_message,
1471
+ gallery_head,
1472
+ gallery_left_flank,
1473
+ gallery_right_flank,
1474
+ gallery_tail,
1475
+ gallery_misc,
1476
+ ],
1477
+ )
1478
+
1479
+ # Connect dataset selection
1480
+ matches_dataset.select(
1481
+ fn=on_match_selected,
1482
+ outputs=[
1483
+ matched_kpts_viz,
1484
+ clean_comparison_viz,
1485
+ selected_match_header,
1486
+ head_indicator,
1487
+ left_flank_indicator,
1488
+ right_flank_indicator,
1489
+ tail_indicator,
1490
+ misc_indicator,
1491
+ head_empty_message,
1492
+ left_flank_empty_message,
1493
+ right_flank_empty_message,
1494
+ tail_empty_message,
1495
+ misc_empty_message,
1496
+ gallery_head,
1497
+ gallery_left_flank,
1498
+ gallery_right_flank,
1499
+ gallery_tail,
1500
+ gallery_misc,
1501
+ ],
1502
+ )
1503
+
1504
+ # Tab 2: Explore Catalog
1505
+ with gr.Tab("Explore Catalog"):
1506
+ gr.Markdown(
1507
+ """
1508
+ ## Snow Leopard Catalog Browser
1509
+
1510
+ Browse the reference catalog of known snow leopard individuals.
1511
+ Each individual has multiple reference images from different body parts and locations.
1512
+ """
1513
+ )
1514
+
1515
+ # Display catalog statistics
1516
+ stats = catalog_index.get("statistics", {})
1517
+ formatted_locations = [loc.replace("_", " ").title() for loc in stats.get("locations", [])]
1518
+ gr.Markdown(
1519
+ f"""
1520
+ ### Catalog Statistics
1521
+ - **Total Individuals:** {stats.get("total_individuals", "N/A")}
1522
+ - **Total Images:** {stats.get("total_reference_images", "N/A")}
1523
+ - **Locations:** {", ".join(formatted_locations)}
1524
+ - **Body Parts:** {", ".join(stats.get("body_parts", []))}
1525
+ """
1526
+ )
1527
+
1528
+ gr.Markdown("---")
1529
+ gr.Markdown("### Individual Leopards by Location")
1530
+
1531
+ # Group individuals by location
1532
+ individuals_by_location = {}
1533
+ for individual_data in individuals_data:
1534
+ location = individual_data.get("location", "unknown")
1535
+ if location not in individuals_by_location:
1536
+ individuals_by_location[location] = []
1537
+ individuals_by_location[location].append(individual_data)
1538
+
1539
+ # Create tabs for each location
1540
+ with gr.Tabs():
1541
+ for location in sorted(individuals_by_location.keys()):
1542
+ with gr.Tab(f"{location.replace('_', ' ').title()}"):
1543
+ # Create subtabs for each individual in this location
1544
+ with gr.Tabs():
1545
+ for leopard_data in individuals_by_location[location]:
1546
+ create_leopard_tab(
1547
+ leopard_metadata=leopard_data, config=config
1548
+ )
1549
+
1550
+ # Cleanup on app close
1551
+ app.unload(cleanup_temp_files)
1552
+
1553
+ # Load first example image on startup
1554
+ def load_first_example():
1555
+ """Load the first example image when the app starts."""
1556
+ if example_images:
1557
+ try:
1558
+ first_image = Image.open(example_images[0])
1559
+ return first_image
1560
+ except Exception as e:
1561
+ logger.error(f"Error loading first example image: {e}")
1562
+ return None
1563
+ return None
1564
+
1565
+ app.load(fn=load_first_example, outputs=[image_input])
1566
+
1567
+ return app
1568
+
1569
+
1570
+ if __name__ == "__main__":
1571
+ # Ensure SAM model is downloaded
1572
+ logger.info("Checking for SAM HQ model...")
1573
+ sam_path = ensure_sam_model()
1574
+
1575
+ # Validate required directories exist
1576
+ if not CATALOG_ROOT.exists():
1577
+ logger.error(f"Catalog not found: {CATALOG_ROOT}")
1578
+ logger.error("Please ensure catalog data is present in data/catalog/")
1579
+ exit(1)
1580
+
1581
+ if not EXAMPLES_DIR.exists():
1582
+ logger.warning(f"Examples directory not found: {EXAMPLES_DIR}")
1583
+ EXAMPLES_DIR.mkdir(parents=True, exist_ok=True)
1584
+
1585
+ # Create config
1586
+ config = AppConfig(
1587
+ model_path=None, # Not using YOLO
1588
+ catalog_root=CATALOG_ROOT,
1589
+ examples_dir=EXAMPLES_DIR,
1590
+ top_k=TOP_K_DEFAULT,
1591
+ port=7860,
1592
+ share=False,
1593
+ sam_checkpoint_path=sam_path,
1594
+ sam_model_type=SAM_MODEL_TYPE,
1595
+ gdino_model_id=GDINO_MODEL_ID,
1596
+ text_prompt=TEXT_PROMPT,
1597
+ )
1598
+
1599
+ # Build and launch app
1600
+ logger.info("Building Gradio interface...")
1601
+ app = create_app(config)
1602
+
1603
+ logger.info("Launching app...")
1604
+ app.launch(
1605
+ server_name="0.0.0.0",
1606
+ server_port=config.port,
1607
+ share=config.share,
1608
+ )
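The nested tab layout above depends on individuals being grouped by location and on folder-style location names being turned into readable tab titles. A minimal sketch of that grouping follows; the record shape (a `location` key on each dict), the helper names `group_by_location` and `tab_title`, and the sample locations are assumptions made for illustration, not taken from the app.

```python
from collections import defaultdict


def group_by_location(individuals: list[dict]) -> dict[str, list[dict]]:
    """Group records by their 'location' key (assumed field name)."""
    grouped: dict[str, list[dict]] = defaultdict(list)
    for individual in individuals:
        grouped[individual["location"]].append(individual)
    return dict(grouped)


def tab_title(location: str) -> str:
    """Turn a folder-style name such as 'east_valley' into 'East Valley'."""
    return location.replace("_", " ").title()


# Hypothetical records; the real catalog metadata likely carries more fields.
individuals = [
    {"name": "individual_a", "location": "north_ridge"},
    {"name": "individual_b", "location": "east_valley"},
    {"name": "individual_c", "location": "north_ridge"},
]

grouped = group_by_location(individuals)
# Sorting the keys keeps the tab order stable between runs.
titles = [tab_title(location) for location in sorted(grouped)]
```

Each entry in `titles` would label one outer tab, and the list stored under the matching key in `grouped` would drive the inner tabs.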
data/cache.tar.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d9e2f29be8e2e38d250becadb88625bb598a9a7359e1107a90b48edbfadddcc
3
+ size 213437895
data/catalog.tar.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:034ffb5eec607ebcce6372a725d947f0607866d9d89272e2d20b32604ffcfbe7
3
+ size 276174890
data/examples/07070305 Agim.JPG ADDED

Git LFS Details

  • SHA256: c907df41236bd312a0ef130cd31625c47e7e393843c3ee6938cde14d0551a9da
  • Pointer size: 132 Bytes
  • Size of remote file: 1.93 MB
data/examples/08190121 Karindas.JPG ADDED

Git LFS Details

  • SHA256: 9fba592e2b79cd73e4389ffb95099f6b9fb1548e85049ab3609504e7303c1454
  • Pointer size: 132 Bytes
  • Size of remote file: 4.13 MB
data/examples/08190742 Ayima.jpg ADDED

Git LFS Details

  • SHA256: 816824ee1c8e4cf2fa8f8e0c8e07730858e919486f1b1a1181861372e4a182d6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.12 MB
data/examples/09150237 AIKA.JPG ADDED

Git LFS Details

  • SHA256: cf8c7631eb84945cf7b9b436af75afb2e8c29caccfccaee91d1e62b295ac3f47
  • Pointer size: 132 Bytes
  • Size of remote file: 3.08 MB
data/examples/IMG_7189 Ayima.JPG ADDED

Git LFS Details

  • SHA256: 3c9623fa08b4b7733b5e3668060dc1aa848d91b1a1689492cb29852f7711f785
  • Pointer size: 131 Bytes
  • Size of remote file: 986 kB
pyproject.toml ADDED
@@ -0,0 +1,32 @@
1
+ [project]
2
+ name = "snowleopard-reid-gradio"
3
+ version = "0.1.0"
4
+ description = "Snow Leopard Re-Identification Gradio App for Hugging Face Spaces"
5
+ requires-python = ">=3.11"
6
+ dependencies = [
7
+ "gradio>=5.49.1",
8
+ "torch>=2.0.0",
9
+ "transformers>=4.30.0",
10
+ "timm>=0.9.0",
11
+ "ultralytics>=8.3.78",
12
+ "segment-anything-hq>=0.3.0",
13
+ "lightglue @ git+https://github.com/cvg/LightGlue.git",
14
+ "opencv-python>=4.11.0.86",
15
+ "numpy>=2.0.0",
16
+ "pillow>=11.0.0",
17
+ "pydantic>=2.0.0",
18
+ "pyyaml>=6.0.0",
19
+ "matplotlib>=3.8.0",
20
+ "scipy>=1.10.0",
21
+ "huggingface_hub>=0.20.0",
22
+ ]
23
+
24
+ [tool.hatch.build.targets.wheel]
25
+ packages = ["src/snowleopard_reid"]
26
+
27
+ [tool.hatch.metadata]
28
+ allow-direct-references = true
29
+
30
+ [build-system]
31
+ requires = ["hatchling"]
32
+ build-backend = "hatchling.build"
requirements.txt ADDED
@@ -0,0 +1,24 @@
1
+ # Core ML
2
+ torch>=2.0.0
3
+ transformers>=4.30.0
4
+ timm>=0.9.0
5
+ ultralytics>=8.3.78
6
+ segment-anything-hq>=0.3.0
7
+
8
+ # Feature matching (git install)
9
+ lightglue @ git+https://github.com/cvg/LightGlue.git
10
+
11
+ # Web UI
12
+ gradio>=5.49.1
13
+
14
+ # Image processing
15
+ opencv-python>=4.11.0.86
16
+ numpy>=2.0.0
17
+ pillow>=11.0.0
18
+ matplotlib>=3.8.0
19
+
20
+ # Utilities
21
+ pydantic>=2.0.0
22
+ pyyaml>=6.0.0
23
+ scipy>=1.10.0
24
+ huggingface_hub>=0.20.0
scripts/create_archives.py ADDED
@@ -0,0 +1,154 @@
1
+ #!/usr/bin/env python3
2
+ """Create compressed archives of catalog and cache data.
3
+
4
+ This script packages the catalog and cached_results directories into
5
+ tar.gz archives for efficient storage in Git LFS.
6
+
7
+ Usage:
8
+ # Create both archives
9
+ python scripts/create_archives.py
10
+
11
+ # Create only catalog archive
12
+ python scripts/create_archives.py --catalog-only
13
+
14
+ # Create only cache archive
15
+ python scripts/create_archives.py --cache-only
16
+
17
+ # Show archive info without creating
18
+ python scripts/create_archives.py --info
19
+ """
20
+
21
+ import argparse
22
+ import logging
23
+ import tarfile
24
+ from pathlib import Path
25
+
26
+ # Configure logging
27
+ logging.basicConfig(
28
+ level=logging.INFO,
29
+ format="%(asctime)s - %(levelname)s - %(message)s",
30
+ )
31
+ logger = logging.getLogger(__name__)
32
+
33
+ # Paths
34
+ PROJECT_ROOT = Path(__file__).parent.parent
35
+ CATALOG_DIR = PROJECT_ROOT / "data" / "catalog"
36
+ CACHE_DIR = PROJECT_ROOT / "cached_results"
37
+ CATALOG_ARCHIVE = PROJECT_ROOT / "data" / "catalog.tar.gz"
38
+ CACHE_ARCHIVE = PROJECT_ROOT / "data" / "cache.tar.gz"
39
+
40
+
41
+ def get_dir_size(path: Path) -> int:
42
+ """Get total size of directory in bytes."""
43
+ return sum(f.stat().st_size for f in path.rglob("*") if f.is_file())
44
+
45
+
46
+ def get_file_count(path: Path) -> int:
47
+ """Get total number of files in directory."""
48
+ return sum(1 for f in path.rglob("*") if f.is_file())
49
+
50
+
51
+ def format_size(size_bytes: int) -> str:
52
+ """Format size in human-readable format."""
53
+ for unit in ["B", "KB", "MB", "GB"]:
54
+ if size_bytes < 1024:
55
+ return f"{size_bytes:.1f} {unit}"
56
+ size_bytes /= 1024
57
+ return f"{size_bytes:.1f} TB"
58
+
59
+
60
+ def create_archive(source_dir: Path, archive_path: Path) -> None:
61
+ """Create a tar.gz archive from a directory.
62
+
63
+ Args:
64
+ source_dir: Directory to archive
65
+ archive_path: Output archive path
66
+ """
67
+ if not source_dir.exists():
68
+ logger.error(f"Source directory not found: {source_dir}")
69
+ return
70
+
71
+ source_size = get_dir_size(source_dir)
72
+ file_count = get_file_count(source_dir)
73
+ logger.info(f"Archiving {source_dir.name}/")
74
+ logger.info(f" Source: {format_size(source_size)} ({file_count} files)")
75
+
76
+ # Create archive
77
+ archive_path.parent.mkdir(parents=True, exist_ok=True)
78
+
79
+ with tarfile.open(archive_path, "w:gz") as tar:
80
+ # Add directory with its name as the archive root
81
+ tar.add(source_dir, arcname=source_dir.name)
82
+
83
+ archive_size = archive_path.stat().st_size
84
+ compression_ratio = (1 - archive_size / source_size) * 100 if source_size > 0 else 0
85
+
86
+ logger.info(f" Archive: {format_size(archive_size)}")
87
+ logger.info(f" Compression: {compression_ratio:.1f}% reduction")
88
+ logger.info(f" Created: {archive_path}")
89
+
90
+
91
+ def show_info() -> None:
92
+ """Show information about directories and existing archives."""
93
+ print("\n=== Directory Info ===")
94
+
95
+ for name, path in [("Catalog", CATALOG_DIR), ("Cache", CACHE_DIR)]:
96
+ if path.exists():
97
+ size = get_dir_size(path)
98
+ count = get_file_count(path)
99
+ print(f"{name}: {format_size(size)} ({count} files)")
100
+ else:
101
+ print(f"{name}: not found")
102
+
103
+ print("\n=== Archive Info ===")
104
+
105
+ for name, path in [("Catalog", CATALOG_ARCHIVE), ("Cache", CACHE_ARCHIVE)]:
106
+ if path.exists():
107
+ size = path.stat().st_size
108
+ print(f"{name}: {format_size(size)}")
109
+ else:
110
+ print(f"{name}: not created")
111
+
112
+
113
+ def main():
114
+ parser = argparse.ArgumentParser(
115
+ description="Create compressed archives of catalog and cache data"
116
+ )
117
+ parser.add_argument(
118
+ "--catalog-only",
119
+ action="store_true",
120
+ help="Only create catalog archive",
121
+ )
122
+ parser.add_argument(
123
+ "--cache-only",
124
+ action="store_true",
125
+ help="Only create cache archive",
126
+ )
127
+ parser.add_argument(
128
+ "--info",
129
+ action="store_true",
130
+ help="Show info about directories and archives",
131
+ )
132
+
133
+ args = parser.parse_args()
134
+
135
+ if args.info:
136
+ show_info()
137
+ return
138
+
139
+ # Determine what to archive
140
+ do_catalog = not args.cache_only
141
+ do_cache = not args.catalog_only
142
+
143
+ if do_catalog:
144
+ create_archive(CATALOG_DIR, CATALOG_ARCHIVE)
145
+
146
+ if do_cache:
147
+ create_archive(CACHE_DIR, CACHE_ARCHIVE)
148
+
149
+ print("\n=== Summary ===")
150
+ show_info()
151
+
152
+
153
+ if __name__ == "__main__":
154
+ main()
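Because `create_archive` stores the directory under its own name (`arcname=source_dir.name`), the inverse step can extract into the parent directory and recover the same folder. The sketch below shows one way that could look; `extract_archive` is a hypothetical helper (the repository's actual extraction code is not shown in this part of the diff), and the round-trip demo uses throwaway files in a temporary directory.

```python
import tarfile
import tempfile
from pathlib import Path


def extract_archive(archive_path: Path, dest_dir: Path) -> Path:
    """Extract a tar.gz archive into dest_dir and return the extracted root."""
    dest_dir.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive_path, "r:gz") as tar:
        # The first member is the top-level directory added via arcname.
        root_name = tar.getnames()[0].split("/")[0]
        # Prefer the safer "data" extraction filter when this Python has it.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(dest_dir, filter="data")
        else:
            tar.extractall(dest_dir)
    return dest_dir / root_name


# Round trip on demo data: archive a small directory, then extract it again.
with tempfile.TemporaryDirectory() as tmp:
    tmp_path = Path(tmp)
    source = tmp_path / "catalog"
    (source / "site_a").mkdir(parents=True)
    (source / "site_a" / "meta.txt").write_text("demo")

    archive = tmp_path / "catalog.tar.gz"
    with tarfile.open(archive, "w:gz") as tar:
        tar.add(source, arcname=source.name)

    root = extract_archive(archive, tmp_path / "out")
    root_name = root.name
    extracted_ok = (root / "site_a" / "meta.txt").read_text() == "demo"
```

Extracting `data/catalog.tar.gz` into `data/` this way would recreate `data/catalog/`, which matches the path the archive script reads from.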
scripts/precompute_cache.py ADDED
@@ -0,0 +1,515 @@
1
+ #!/usr/bin/env python3
2
+ """Pre-compute pipeline results for all example images.
3
+
4
+ This script runs the full snow leopard identification pipeline on all example images
5
+ with all available feature extractors, caching the results for instant display
6
+ in the Gradio app.
7
+
8
+ Usage:
9
+ # Process all example images with all extractors
10
+ python scripts/precompute_cache.py
11
+
12
+ # Process specific images
13
+ python scripts/precompute_cache.py --images IMG_001.jpg IMG_002.jpg
14
+
15
+ # Process with specific extractors only
16
+ python scripts/precompute_cache.py --extractors sift superpoint
17
+
18
+ # Clear cache and regenerate all
19
+ python scripts/precompute_cache.py --clear
20
+
21
+ # Show cache summary
22
+ python scripts/precompute_cache.py --summary
23
+ """
24
+
25
+ import argparse
26
+ import logging
27
+ import sys
28
+ import tempfile
29
+ from pathlib import Path
30
+
31
+ import cv2
32
+ import numpy as np
33
+ import torch
34
+ from PIL import Image
35
+
36
+ # Add project root to path for imports
37
+ PROJECT_ROOT = Path(__file__).parent.parent
38
+ sys.path.insert(0, str(PROJECT_ROOT / "src"))
39
+
40
+ from snowleopard_reid.cache import (
41
+ CACHE_DIR,
42
+ clear_cache,
43
+ extract_location_body_part_from_filepath,
44
+ get_cache_dir,
45
+ get_cache_summary,
46
+ )
47
+ from snowleopard_reid.pipeline.stages import (
48
+ run_feature_extraction_stage,
49
+ run_matching_stage,
50
+ run_preprocess_stage,
51
+ run_segmentation_stage,
52
+ select_best_mask,
53
+ )
54
+ from snowleopard_reid.pipeline.stages.segmentation import (
55
+ load_gdino_model,
56
+ load_sam_predictor,
57
+ )
58
+ from snowleopard_reid.visualization import draw_keypoints_overlay
59
+
60
+ # Configure logging
61
+ logging.basicConfig(
62
+ level=logging.INFO,
63
+ format="%(asctime)s - %(levelname)s - %(message)s",
64
+ )
65
+ logger = logging.getLogger(__name__)
66
+
67
+ # Configuration
68
+ CATALOG_ROOT = PROJECT_ROOT / "data" / "catalog"
69
+ SAM_CHECKPOINT_DIR = PROJECT_ROOT / "data" / "models"
70
+ SAM_CHECKPOINT_NAME = "sam_hq_vit_l.pth"
71
+ EXAMPLES_DIR = PROJECT_ROOT / "data" / "examples"
72
+ GDINO_MODEL_ID = "IDEA-Research/grounding-dino-base"
73
+ TEXT_PROMPT = "a snow leopard."
74
+ SAM_MODEL_TYPE = "vit_l"
75
+ # Set very high to get ALL matches (will be limited by catalog size)
76
+ TOP_K_ALL = 1000
77
+ # Default top_k for display
78
+ TOP_K_DEFAULT = 5
79
+
80
+ # All available extractors
81
+ ALL_EXTRACTORS = ["sift", "superpoint", "disk", "aliked"]
82
+
83
+
84
+ def ensure_sam_model() -> Path:
85
+ """Download SAM HQ model if not present.
86
+
87
+ Returns:
88
+ Path to the SAM HQ checkpoint file
89
+ """
90
+ from huggingface_hub import hf_hub_download
91
+
92
+ sam_path = SAM_CHECKPOINT_DIR / SAM_CHECKPOINT_NAME
93
+ if not sam_path.exists():
94
+ logger.info("Downloading SAM HQ model (1.6GB)...")
95
+ SAM_CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
96
+ hf_hub_download(
97
+ repo_id="lkeab/hq-sam",
98
+ filename=SAM_CHECKPOINT_NAME,
99
+ local_dir=SAM_CHECKPOINT_DIR,
100
+ )
101
+ logger.info("SAM HQ model downloaded successfully")
102
+ return sam_path
103
+
104
+
105
+ def create_segmentation_viz(image_path: Path, mask: np.ndarray) -> Image.Image:
106
+ """Create visualization of segmentation mask overlaid on image.
107
+
108
+ Args:
109
+ image_path: Path to original image
110
+ mask: Binary segmentation mask
111
+
112
+ Returns:
113
+ PIL Image with segmentation overlay
114
+ """
115
+ img = cv2.imread(str(image_path))
116
+ img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
117
+
118
+ # Resize mask to match image dimensions if needed
119
+ if mask.shape[:2] != img_rgb.shape[:2]:
120
+ mask_resized = cv2.resize(
121
+ mask.astype(np.uint8),
122
+ (img_rgb.shape[1], img_rgb.shape[0]),
123
+ interpolation=cv2.INTER_NEAREST,
124
+ )
125
+ else:
126
+ mask_resized = mask
127
+
128
+ # Create colored overlay
129
+ overlay = img_rgb.copy()
130
+ overlay[mask_resized > 0] = [255, 0, 0] # Red for masked region
131
+
132
+ # Blend
133
+ alpha = 0.4
134
+ blended = cv2.addWeighted(img_rgb, 1 - alpha, overlay, alpha, 0)
135
+
136
+ return Image.fromarray(blended)
137
+
138
+
139
+ def process_and_cache(
140
+ image_path: Path,
141
+ extractor: str,
142
+ gdino_processor,
143
+ gdino_model,
144
+ sam_predictor,
145
+ device: str,
146
+ ) -> bool:
147
+ """Run full pipeline and cache ALL results for one image/extractor combination.
148
+
149
+ This version caches ALL matches (not just top-k) with location/body_part
150
+ metadata, and stores NPZ pairwise data for on-demand visualization generation.
151
+
152
+ Args:
153
+ image_path: Path to example image
154
+ extractor: Feature extractor to use
155
+ gdino_processor: Pre-loaded Grounding DINO processor
156
+ gdino_model: Pre-loaded Grounding DINO model
157
+ sam_predictor: Pre-loaded SAM HQ predictor
158
+ device: Device to run on ('cuda' or 'cpu')
159
+
160
+ Returns:
161
+ True if successful, False otherwise
162
+ """
163
+ logger.info(f"Processing {image_path.name} with {extractor.upper()}...")
164
+
165
+ try:
166
+ # Create temporary directory for intermediate files
167
+ with tempfile.TemporaryDirectory(prefix="snowleopard_cache_") as temp_dir:
168
+ temp_dir = Path(temp_dir)
169
+
170
+ # ================================================================
171
+ # Stage 1: Segmentation (GDINO+SAM)
172
+ # ================================================================
173
+ logger.info(" Running GDINO+SAM segmentation...")
174
+ seg_stage = run_segmentation_stage(
175
+ image_path=image_path,
176
+ strategy="gdino_sam",
177
+ confidence_threshold=0.2,
178
+ device=device,
179
+ gdino_processor=gdino_processor,
180
+ gdino_model=gdino_model,
181
+ sam_predictor=sam_predictor,
182
+ text_prompt=TEXT_PROMPT,
183
+ box_threshold=0.30,
184
+ text_threshold=0.20,
185
+ )
186
+
187
+ predictions = seg_stage["data"]["predictions"]
188
+ if not predictions:
189
+ logger.warning(f" No snow leopards detected in {image_path.name}")
190
+ return False
191
+
192
+ # ================================================================
193
+ # Stage 2: Mask Selection
194
+ # ================================================================
195
+ logger.info(" Selecting best mask...")
196
+ selected_idx, selected_pred = select_best_mask(
197
+ predictions,
198
+ strategy="confidence_area",
199
+ )
200
+
201
+ # Create segmentation visualization
202
+ segmentation_image = create_segmentation_viz(
203
+ image_path=image_path,
204
+ mask=selected_pred["mask"],
205
+ )
206
+
207
+ # ================================================================
208
+ # Stage 3: Preprocessing
209
+ # ================================================================
210
+ logger.info(" Preprocessing...")
211
+ prep_stage = run_preprocess_stage(
212
+ image_path=image_path,
213
+ mask=selected_pred["mask"],
214
+ padding=5,
215
+ )
216
+
217
+ cropped_image = prep_stage["data"]["cropped_image"]
218
+
219
+ # Save cropped image for visualization functions
220
+ cropped_path = temp_dir / "cropped.jpg"
221
+ cropped_image.save(cropped_path)
222
+
223
+ # ================================================================
224
+ # Stage 4: Feature Extraction
225
+ # ================================================================
226
+ logger.info(f" Extracting features ({extractor.upper()})...")
227
+ feat_stage = run_feature_extraction_stage(
228
+ image=cropped_image,
229
+ extractor=extractor,
230
+ max_keypoints=2048,
231
+ device=device,
232
+ )
233
+
234
+ query_features = feat_stage["data"]["features"]
235
+
236
+ # Create keypoints visualization
237
+ query_kpts = query_features["keypoints"].cpu().numpy()
238
+ keypoints_image = draw_keypoints_overlay(
239
+ image_path=cropped_path,
240
+ keypoints=query_kpts,
241
+ max_keypoints=500,
242
+ color="blue",
243
+ ps=10,
244
+ )
245
+
246
+ # ================================================================
247
+ # Stage 5: Matching - Get ALL matches
248
+ # ================================================================
249
+ logger.info(" Matching against catalog (ALL matches)...")
250
+ temp_pairwise_dir = temp_dir / "pairwise"
251
+ temp_pairwise_dir.mkdir(exist_ok=True)
252
+
253
+ match_stage = run_matching_stage(
254
+ query_features=query_features,
255
+ catalog_path=CATALOG_ROOT,
256
+ top_k=TOP_K_ALL, # Get ALL matches
257
+ extractor=extractor,
258
+ device=device,
259
+ query_image_path=str(cropped_path),
260
+ pairwise_output_dir=temp_pairwise_dir,
261
+ )
262
+
263
+ matches = match_stage["data"]["matches"]
264
+
265
+ if not matches:
266
+ logger.warning(f" No matches found for {image_path.name}")
267
+ return False
268
+
269
+ logger.info(f" Found {len(matches)} matches")
270
+
271
+ # ================================================================
272
+ # Enrich matches with location/body_part
273
+ # ================================================================
274
+ logger.info(" Adding location/body_part metadata...")
275
+ for match in matches:
276
+ location, body_part = extract_location_body_part_from_filepath(
277
+ match["filepath"]
278
+ )
279
+ match["location"] = location
280
+ match["body_part"] = body_part
281
+
282
+ # ================================================================
283
+ # Set up cache directory
284
+ # ================================================================
285
+ cache_dir = get_cache_dir(image_path, extractor)
286
+ cache_dir.mkdir(parents=True, exist_ok=True)
287
+ pairwise_dir = cache_dir / "pairwise"
288
+ pairwise_dir.mkdir(exist_ok=True)
289
+
290
+ # ================================================================
291
+ # Copy NPZ files with catalog_id naming (not rank-based)
292
+ # ================================================================
293
+ logger.info(" Copying NPZ pairwise data...")
294
+ npz_count = 0
295
+ for match in matches:
296
+ catalog_id = match["catalog_id"]
297
+ rank = match["rank"]
298
+
299
+ # Source NPZ (rank-based naming from matching stage)
300
+ src_npz = temp_pairwise_dir / f"rank_{rank:02d}_{catalog_id}.npz"
301
+
302
+ # Destination NPZ (catalog_id naming for cache)
303
+ dst_npz = pairwise_dir / f"{catalog_id}.npz"
304
+
305
+ if src_npz.exists():
306
+ import shutil
307
+ shutil.copy2(src_npz, dst_npz)
308
+ npz_count += 1
309
+
310
+ logger.info(f" Copied {npz_count} NPZ files")
311
+
312
+ # ================================================================
313
+ # Build Predictions Dict (v2.0 format with all_matches)
314
+ # ================================================================
315
+ predictions_dict = {
316
+ "format_version": "2.0",
317
+ "query_image": str(image_path),
318
+ "extractor": extractor,
319
+ "pipeline": {
320
+ "segmentation": {
321
+ "strategy": "gdino_sam",
322
+ "num_predictions": len(predictions),
323
+ "selected_idx": selected_idx,
324
+ "confidence": float(selected_pred["confidence"]),
325
+ },
326
+ "preprocessing": {
327
+ "padding": prep_stage["config"]["padding"],
328
+ },
329
+ "features": {
330
+ "num_keypoints": int(feat_stage["metrics"]["num_keypoints"]),
331
+ "extractor": extractor,
332
+ "max_keypoints": 2048,
333
+ },
334
+ "matching": {
335
+ "num_catalog_images": match_stage["metrics"]["num_catalog_images"],
336
+ "num_successful_matches": match_stage["metrics"]["num_successful_matches"],
337
+ },
338
+ },
339
+ "all_matches": matches, # ALL matches with location/body_part
340
+ "top_k": TOP_K_DEFAULT,
341
+ }
342
+
343
+ # ================================================================
344
+ # Save Cache (predictions.json + visualization images)
345
+ # ================================================================
346
+ logger.info(" Saving to cache...")
347
+
348
+ # Save predictions JSON
349
+ import json
350
+ predictions_file = cache_dir / "predictions.json"
351
+ with open(predictions_file, "w") as f:
352
+ json.dump(predictions_dict, f, indent=2)
353
+
354
+ # Save visualization images
355
+ segmentation_image.save(cache_dir / "segmentation.png")
356
+ cropped_image.save(cache_dir / "cropped.png")
357
+ keypoints_image.save(cache_dir / "keypoints.png")
358
+
359
+ # Log cache size
360
+ cache_size = sum(
361
+ f.stat().st_size for f in cache_dir.rglob("*") if f.is_file()
362
+ )
363
+ logger.info(
364
+ f" Cached: {cache_dir.name} ({cache_size / 1024 / 1024:.2f} MB)"
365
+ )
366
+ logger.info(f" {len(matches)} matches, {npz_count} NPZ files")
367
+
368
+ return True
369
+
370
+ except Exception as e:
371
+ logger.error(f" Failed: {e}", exc_info=True)
372
+ return False
373
+
374
+
375
+ def main():
376
+ parser = argparse.ArgumentParser(
377
+ description="Pre-compute pipeline results for example images"
378
+ )
379
+ parser.add_argument(
380
+ "--images",
381
+ nargs="+",
382
+ help="Specific image filenames to process (default: all in examples/)",
383
+ )
384
+ parser.add_argument(
385
+ "--extractors",
386
+ nargs="+",
387
+ choices=ALL_EXTRACTORS,
388
+ default=ALL_EXTRACTORS,
389
+ help="Feature extractors to use (default: all)",
390
+ )
391
+ parser.add_argument(
392
+ "--clear",
393
+ action="store_true",
394
+ help="Clear all cached results before processing",
395
+ )
396
+ parser.add_argument(
397
+ "--summary",
398
+ action="store_true",
399
+ help="Show cache summary and exit",
400
+ )
401
+ parser.add_argument(
402
+ "--device",
403
+ choices=["cpu", "cuda"],
404
+ default=None,
405
+ help="Device to run on (default: auto-detect)",
406
+ )
407
+
408
+ args = parser.parse_args()
409
+
410
+ # Show summary and exit
411
+ if args.summary:
412
+ summary = get_cache_summary()
413
+ print("\n=== Cache Summary ===")
414
+ print(f"Total cached: {summary['total_cached']} items")
415
+ print(f"Total size: {summary['total_size_mb']:.2f} MB")
416
+ print("\nCached items:")
417
+ for item in summary["cached_items"]:
418
+ print(f" - {item['image_stem']} ({item['extractor']}): {item['size_mb']:.2f} MB")
419
+ return
420
+
421
+ # Clear cache if requested
422
+ if args.clear:
423
+ logger.info("Clearing cache...")
424
+ clear_cache()
425
+ logger.info("Cache cleared")
426
+
427
+ # Determine device
428
+ if args.device:
429
+ device = args.device
430
+ else:
431
+ device = "cuda" if torch.cuda.is_available() else "cpu"
432
+
433
+ logger.info(f"Using device: {device}")
434
+ if device == "cuda":
435
+ logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
436
+
437
+ # Find example images
438
+ if args.images:
439
+ image_paths = [EXAMPLES_DIR / img for img in args.images]
440
+ # Filter existing files
441
+ image_paths = [p for p in image_paths if p.exists()]
442
+ if not image_paths:
443
+ logger.error("No valid image paths found")
444
+ sys.exit(1)
445
+ else:
446
+ image_paths = (
447
+ list(EXAMPLES_DIR.glob("*.jpg"))
448
+ + list(EXAMPLES_DIR.glob("*.JPG"))
449
+ + list(EXAMPLES_DIR.glob("*.png"))
450
+ )
451
+
452
+ if not image_paths:
453
+ logger.error(f"No example images found in {EXAMPLES_DIR}")
454
+ sys.exit(1)
455
+
456
+ logger.info(f"Found {len(image_paths)} example images")
457
+ logger.info(f"Extractors: {', '.join(args.extractors)}")
458
+
459
+ # Ensure SAM model is downloaded
460
+ logger.info("Checking for SAM HQ model...")
461
+ sam_path = ensure_sam_model()
462
+
463
+ # Load GDINO model once
464
+ logger.info(f"Loading Grounding DINO model: {GDINO_MODEL_ID}...")
465
+ gdino_processor, gdino_model = load_gdino_model(
466
+ model_id=GDINO_MODEL_ID,
467
+ device=device,
468
+ )
469
+ logger.info("Grounding DINO model loaded")
470
+
471
+ # Load SAM HQ model once
472
+ logger.info(f"Loading SAM HQ model from {sam_path}...")
473
+ sam_predictor = load_sam_predictor(
474
+ checkpoint_path=sam_path,
475
+ model_type=SAM_MODEL_TYPE,
476
+ device=device,
477
+ )
478
+ logger.info("SAM HQ model loaded")
479
+
480
+ # Process all combinations
481
+ total = len(image_paths) * len(args.extractors)
482
+ success = 0
483
+ failed = 0
484
+
485
+ for i, image_path in enumerate(image_paths):
486
+ for j, extractor in enumerate(args.extractors):
487
+ current = i * len(args.extractors) + j + 1
488
+ logger.info(f"\n[{current}/{total}] Processing...")
489
+
490
+ if process_and_cache(
491
+ image_path=image_path,
492
+ extractor=extractor,
493
+ gdino_processor=gdino_processor,
494
+ gdino_model=gdino_model,
495
+ sam_predictor=sam_predictor,
496
+ device=device,
497
+ ):
498
+ success += 1
499
+ else:
500
+ failed += 1
501
+
502
+ # Final summary
503
+ logger.info("\n" + "=" * 50)
504
+ logger.info("PRECOMPUTATION COMPLETE")
505
+ logger.info("=" * 50)
506
+ logger.info(f"Success: {success}/{total}")
507
+ logger.info(f"Failed: {failed}/{total}")
508
+
509
+ # Show cache summary
510
+ summary = get_cache_summary()
511
+ logger.info(f"Total cache size: {summary['total_size_mb']:.2f} MB")
512
+
513
+
514
+ if __name__ == "__main__":
515
+ main()
src/snowleopard_reid/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ """Snow Leopard Re-Identification Package."""
2
+
3
+ from snowleopard_reid.cache import (
4
+ clear_cache,
5
+ get_cache_summary,
6
+ is_cached,
7
+ load_cached_match_visualizations,
8
+ load_cached_results,
9
+ save_cache_results,
10
+ )
11
+ from snowleopard_reid.images import resize_image_if_needed
12
+ from snowleopard_reid.utils import get_device
13
+
14
+ __all__ = [
15
+ "clear_cache",
16
+ "get_cache_summary",
17
+ "get_device",
18
+ "is_cached",
19
+ "load_cached_match_visualizations",
20
+ "load_cached_results",
21
+ "resize_image_if_needed",
22
+ "save_cache_results",
23
+ ]
24
+
25
+
26
+ def main() -> None:
27
+ print("Hello from snowleopard-reid!")
src/snowleopard_reid/cache.py ADDED
@@ -0,0 +1,421 @@
1
+ """Cache utilities for precomputed pipeline results.
2
+
3
+ This module provides functions for loading and saving cached pipeline results,
4
+ enabling instant display of results for example images without running the
5
+ expensive pipeline (GDINO+SAM segmentation, feature extraction, matching) on CPU.
6
+
7
+ Cache Structure (v2.0 - supports filtering):
8
+ cached_results/
9
+ ├── {image_stem}_{extractor}/
10
+ │ ├── predictions.json # ALL matches with location/body_part
11
+ │ ├── segmentation.png # Segmentation visualization
12
+ │ ├── cropped.png # Cropped snow leopard image
13
+ │ ├── keypoints.png # Extracted keypoints visualization
14
+ │ └── pairwise/
15
+ │ ├── {catalog_id}.npz # NPZ data for ALL matches
16
+ │ └── ... # (visualizations generated on-demand)
17
+ """
18
+
19
+ import copy
20
+ import json
21
+ import logging
22
+ from pathlib import Path
23
+
24
+ import numpy as np
25
+ from PIL import Image
26
+
27
+ from snowleopard_reid.visualization import (
28
+ draw_matched_keypoints,
29
+ draw_side_by_side_comparison,
30
+ )
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+ # Cache directory, resolved relative to the current working directory (expected to be the project root)
35
+ CACHE_DIR = Path("cached_results")
36
+
37
+
38
+ def get_cache_key(image_path: Path | str, extractor: str) -> str:
39
+ """Generate cache key from image stem and extractor.
40
+
41
+ Args:
42
+ image_path: Path to the query image
43
+ extractor: Feature extractor name (e.g., 'sift', 'superpoint')
44
+
45
+ Returns:
46
+ Cache key string in format "{image_stem}_{extractor}"
47
+ """
48
+ image_path = Path(image_path)
49
+ return f"{image_path.stem}_{extractor}"
50
+
51
+
52
+ def get_cache_dir(image_path: Path | str, extractor: str) -> Path:
53
+ """Get cache directory for an image/extractor combination.
54
+
55
+ Args:
56
+ image_path: Path to the query image
57
+ extractor: Feature extractor name
58
+
59
+ Returns:
60
+ Path to the cache directory
61
+ """
62
+ return CACHE_DIR / get_cache_key(image_path, extractor)
63
+
64
+
65
+ def is_cached(image_path: Path | str, extractor: str) -> bool:
66
+ """Check if results are cached for this image/extractor combination.
67
+
68
+ Args:
69
+ image_path: Path to the query image
70
+ extractor: Feature extractor name
71
+
72
+ Returns:
73
+ True if all required cache files exist
74
+ """
75
+ cache_dir = get_cache_dir(image_path, extractor)
76
+ predictions_file = cache_dir / "predictions.json"
77
+
78
+ if not predictions_file.exists():
79
+ return False
80
+
81
+ # Check for required visualization files
82
+ required_files = [
83
+ "segmentation.png",
84
+ "cropped.png",
85
+ "keypoints.png",
86
+ ]
87
+
88
+ for filename in required_files:
89
+ if not (cache_dir / filename).exists():
90
+ return False
91
+
92
+ return True
93
+
94
+
95
+ def load_cached_results(image_path: Path | str, extractor: str) -> dict:
96
+ """Load all cached results for an image/extractor combination.
97
+
98
+ Args:
99
+ image_path: Path to the query image
100
+ extractor: Feature extractor name
101
+
102
+ Returns:
103
+ Dictionary containing:
104
+ - predictions: Full pipeline predictions dict
105
+ - segmentation_image: PIL Image of segmentation overlay
106
+ - cropped_image: PIL Image of cropped snow leopard
107
+ - keypoints_image: PIL Image of extracted keypoints
108
+ - pairwise_dir: Path to directory with pairwise match data (NPZ files in the v2.0 layout)
109
+
110
+ Raises:
111
+ FileNotFoundError: If cache files don't exist
112
+ """
113
+ cache_dir = get_cache_dir(image_path, extractor)
114
+
115
+ if not cache_dir.exists():
116
+ raise FileNotFoundError(f"Cache directory not found: {cache_dir}")
117
+
118
+ predictions_file = cache_dir / "predictions.json"
119
+ if not predictions_file.exists():
120
+ raise FileNotFoundError(f"Predictions file not found: {predictions_file}")
121
+
122
+ # Load predictions JSON
123
+ with open(predictions_file) as f:
124
+ predictions = json.load(f)
125
+
126
+ # Load visualization images
127
+ segmentation_image = Image.open(cache_dir / "segmentation.png")
128
+ cropped_image = Image.open(cache_dir / "cropped.png")
129
+ keypoints_image = Image.open(cache_dir / "keypoints.png")
130
+
131
+ return {
132
+ "predictions": predictions,
133
+ "segmentation_image": segmentation_image,
134
+ "cropped_image": cropped_image,
135
+ "keypoints_image": keypoints_image,
136
+ "pairwise_dir": cache_dir / "pairwise",
137
+ }
138
+
139
+
140
+ def load_cached_match_visualizations(
141
+ pairwise_dir: Path,
142
+ matches: list[dict],
143
+ ) -> tuple[dict, dict]:
144
+ """Load cached match and clean comparison visualizations.
145
+
146
+ Args:
147
+ pairwise_dir: Path to pairwise visualizations directory
148
+ matches: List of match dictionaries with rank and catalog_id
149
+
150
+ Returns:
151
+ Tuple of (match_visualizations, clean_comparison_visualizations)
152
+ Both are dicts mapping rank -> PIL Image
153
+ """
154
+ match_visualizations = {}
155
+ clean_comparison_visualizations = {}
156
+
157
+ for match in matches:
158
+ rank = match["rank"]
159
+ catalog_id = match["catalog_id"]
160
+
161
+ # Load match visualization
162
+ match_path = pairwise_dir / f"rank_{rank:02d}_{catalog_id}_match.png"
163
+ if match_path.exists():
164
+ match_visualizations[rank] = Image.open(match_path)
165
+
166
+ # Load clean comparison visualization
167
+ clean_path = pairwise_dir / f"rank_{rank:02d}_{catalog_id}_clean.png"
168
+ if clean_path.exists():
169
+ clean_comparison_visualizations[rank] = Image.open(clean_path)
170
+
171
+ return match_visualizations, clean_comparison_visualizations
172
+
173
+
174
+ def save_cache_results(
175
+ image_path: Path | str,
176
+ extractor: str,
177
+ predictions: dict,
178
+ segmentation_image: Image.Image,
179
+ cropped_image: Image.Image,
180
+ keypoints_image: Image.Image,
181
+ match_visualizations: dict[int, Image.Image],
182
+ clean_comparison_visualizations: dict[int, Image.Image],
183
+ matches: list[dict],
184
+ ) -> Path:
185
+ """Save pipeline results to cache.
186
+
187
+ Args:
188
+ image_path: Path to the original query image
189
+ extractor: Feature extractor name
190
+ predictions: Full pipeline predictions dictionary
191
+ segmentation_image: PIL Image of segmentation overlay
192
+ cropped_image: PIL Image of cropped snow leopard
193
+ keypoints_image: PIL Image of extracted keypoints
194
+ match_visualizations: Dict mapping rank -> match visualization PIL Image
195
+ clean_comparison_visualizations: Dict mapping rank -> clean comparison PIL Image
196
+ matches: List of match dictionaries with rank and catalog_id
197
+
198
+ Returns:
199
+ Path to the cache directory
200
+ """
201
+ cache_dir = get_cache_dir(image_path, extractor)
202
+ cache_dir.mkdir(parents=True, exist_ok=True)
203
+
204
+ # Save predictions JSON
205
+ predictions_file = cache_dir / "predictions.json"
206
+ with open(predictions_file, "w") as f:
207
+ json.dump(predictions, f, indent=2)
208
+ logger.info(f"Saved predictions: {predictions_file}")
209
+
210
+ # Save visualization images
211
+ segmentation_image.save(cache_dir / "segmentation.png")
212
+ cropped_image.save(cache_dir / "cropped.png")
213
+ keypoints_image.save(cache_dir / "keypoints.png")
214
+ logger.info(f"Saved visualization images to {cache_dir}")
215
+
216
+ # Save pairwise match visualizations
217
+ pairwise_dir = cache_dir / "pairwise"
218
+ pairwise_dir.mkdir(exist_ok=True)
219
+
220
+ for match in matches:
221
+ rank = match["rank"]
222
+ catalog_id = match["catalog_id"]
223
+
224
+ # Save match visualization
225
+ if rank in match_visualizations:
226
+ match_path = pairwise_dir / f"rank_{rank:02d}_{catalog_id}_match.png"
227
+ match_visualizations[rank].save(match_path)
228
+
229
+ # Save clean comparison visualization
230
+ if rank in clean_comparison_visualizations:
231
+ clean_path = pairwise_dir / f"rank_{rank:02d}_{catalog_id}_clean.png"
232
+ clean_comparison_visualizations[rank].save(clean_path)
233
+
234
+ logger.info(f"Saved {len(match_visualizations)} pairwise visualizations")
235
+
236
+ return cache_dir
237
+
238
+
239
+ def clear_cache(image_path: Path | str | None = None, extractor: str | None = None) -> None:
240
+ """Clear cache directory.
241
+
242
+ Args:
243
+ image_path: If provided together with extractor, only clear the cache for this image
243
+ extractor: Required alongside image_path to clear a single cache; if either is missing, all caches are cleared
245
+ """
246
+ import shutil
247
+
248
+ if image_path and extractor:
249
+ # Clear specific cache
250
+ cache_dir = get_cache_dir(image_path, extractor)
251
+ if cache_dir.exists():
252
+ shutil.rmtree(cache_dir)
253
+ logger.info(f"Cleared cache: {cache_dir}")
254
+ elif CACHE_DIR.exists():
255
+ # Clear all caches
256
+ shutil.rmtree(CACHE_DIR)
257
+ logger.info(f"Cleared all caches: {CACHE_DIR}")
258
+
259
+
260
+ def get_cache_summary() -> dict:
261
+ """Get summary of cached results.
262
+
263
+ Returns:
264
+ Dictionary with cache statistics
265
+ """
266
+ if not CACHE_DIR.exists():
267
+ return {"total_cached": 0, "total_size_mb": 0, "cached_items": []}
268
+
269
+ cached_items = []
270
+ total_size = 0
271
+
272
+ for cache_dir in CACHE_DIR.iterdir():
273
+ if cache_dir.is_dir():
274
+ # Calculate size
275
+ size = sum(f.stat().st_size for f in cache_dir.rglob("*") if f.is_file())
276
+ total_size += size
277
+
278
+ # Parse cache key
279
+ parts = cache_dir.name.rsplit("_", 1)
280
+ if len(parts) == 2:
281
+ image_stem, extractor = parts
282
+ else:
283
+ image_stem, extractor = cache_dir.name, "unknown"
284
+
285
+ cached_items.append({
286
+ "image_stem": image_stem,
287
+ "extractor": extractor,
288
+ "size_mb": size / (1024 * 1024),
289
+ "path": str(cache_dir),
290
+ })
291
+
292
+ return {
293
+ "total_cached": len(cached_items),
294
+ "total_size_mb": total_size / (1024 * 1024),
295
+ "cached_items": cached_items,
296
+ }
297
+
298
+
299
+ def filter_cached_matches(
300
+ all_matches: list[dict],
301
+ filter_locations: list[str] | None = None,
302
+ filter_body_parts: list[str] | None = None,
303
+ top_k: int = 5,
304
+ ) -> list[dict]:
305
+ """Filter cached matches by location/body_part and return top-k.
306
+
307
+ Args:
308
+ all_matches: List of all cached match dictionaries
309
+ filter_locations: List of locations to filter by (e.g., ["skycrest_valley"])
310
+ filter_body_parts: List of body parts to filter by (e.g., ["head", "right_flank"])
311
+ top_k: Number of top matches to return after filtering
312
+
313
+ Returns:
314
+ List of filtered and re-ranked match dictionaries
315
+ """
316
+ # Make a deep copy to avoid modifying the original
317
+ filtered = [copy.deepcopy(m) for m in all_matches]
318
+
319
+ if filter_locations:
320
+ filtered = [m for m in filtered if m.get("location") in filter_locations]
321
+
322
+ if filter_body_parts:
323
+ filtered = [m for m in filtered if m.get("body_part") in filter_body_parts]
324
+
325
+ # Re-sort by wasserstein (descending - higher is better)
326
+ filtered = sorted(filtered, key=lambda x: x.get("wasserstein", 0), reverse=True)
327
+
328
+ # Re-assign ranks for the filtered top-k
329
+ for i, match in enumerate(filtered[:top_k]):
330
+ match["rank"] = i + 1
331
+
332
+ return filtered[:top_k]
333
+
334
+
335
+ def generate_visualizations_from_npz(
336
+ pairwise_dir: Path,
337
+ matches: list[dict],
338
+ cropped_image_path: Path | str,
339
+ ) -> tuple[dict, dict]:
340
+ """Generate match visualizations on-demand from cached NPZ data.
341
+
342
+ Args:
343
+ pairwise_dir: Path to directory containing NPZ pairwise data files
344
+ matches: List of filtered match dictionaries with catalog_id and filepath
345
+ cropped_image_path: Path to the cropped query image
346
+
347
+ Returns:
348
+ Tuple of (match_visualizations, clean_comparison_visualizations)
349
+ Both are dicts mapping rank -> PIL Image
350
+ """
351
+ match_visualizations = {}
352
+ clean_comparison_visualizations = {}
353
+
354
+ cropped_image_path = Path(cropped_image_path)
355
+
356
+ for match in matches:
357
+ rank = match["rank"]
358
+ catalog_id = match["catalog_id"]
359
+ catalog_image_path = Path(match["filepath"])
360
+
361
+ # Look for NPZ file by catalog_id
362
+ npz_path = pairwise_dir / f"{catalog_id}.npz"
363
+
364
+ if npz_path.exists():
365
+ try:
366
+ pairwise_data = np.load(npz_path, allow_pickle=True)
367
+
368
+ # Generate matched keypoints visualization
369
+ match_viz = draw_matched_keypoints(
370
+ query_image_path=cropped_image_path,
371
+ catalog_image_path=catalog_image_path,
372
+ query_keypoints=pairwise_data["query_keypoints"],
373
+ catalog_keypoints=pairwise_data["catalog_keypoints"],
374
+ match_scores=pairwise_data["match_scores"],
375
+ max_matches=100,
376
+ )
377
+ match_visualizations[rank] = match_viz
378
+
379
+ # Generate clean side-by-side comparison
380
+ clean_viz = draw_side_by_side_comparison(
381
+ query_image_path=cropped_image_path,
382
+ catalog_image_path=catalog_image_path,
383
+ )
384
+ clean_comparison_visualizations[rank] = clean_viz
385
+
386
+ except Exception as e:
387
+ logger.warning(
388
+ f"Failed to generate visualization for {catalog_id}: {e}"
389
+ )
390
+ else:
391
+ logger.warning(f"NPZ file not found for {catalog_id}: {npz_path}")
392
+
393
+ return match_visualizations, clean_comparison_visualizations
394
+
395
+
396
+ def extract_location_body_part_from_filepath(filepath: str) -> tuple[str, str]:
397
+ """Extract location and body_part from catalog image filepath.
398
+
399
+ Expected filepath format:
400
+ .../database/{location}/{individual}/images/{body_part}/{filename}
401
+
402
+ Args:
403
+ filepath: Path to catalog image
404
+
405
+ Returns:
406
+ Tuple of (location, body_part)
407
+ """
408
+ parts = Path(filepath).parts
409
+
410
+ # Find "database" in path and extract location (next part) and body_part
411
+ try:
412
+ db_idx = parts.index("database")
413
+ location = parts[db_idx + 1] if db_idx + 1 < len(parts) else "unknown"
414
+
415
+ # Find "images" in path and get body_part (next part)
416
+ img_idx = parts.index("images")
417
+ body_part = parts[img_idx + 1] if img_idx + 1 < len(parts) else "unknown"
418
+
419
+ return location, body_part
420
+ except (ValueError, IndexError):
421
+ return "unknown", "unknown"
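The filter-and-re-rank behaviour of `filter_cached_matches` in the file above can be tried without installing the package. What follows is a minimal standalone sketch, not the commit's code: the helper name `filter_and_rerank` and the sample match dictionaries (IDs, locations, scores) are made up for illustration, and only the keys `rank`, `catalog_id`, `location`, `body_part`, and `wasserstein` are taken from the diff.

```python
import copy


def filter_and_rerank(all_matches, locations=None, body_parts=None, top_k=5):
    """Hypothetical stand-in that mirrors filter_cached_matches from the diff."""
    # Deep-copy so the caller's cached matches keep their original ranks
    filtered = [copy.deepcopy(m) for m in all_matches]
    if locations:
        filtered = [m for m in filtered if m.get("location") in locations]
    if body_parts:
        filtered = [m for m in filtered if m.get("body_part") in body_parts]
    # Higher score first, matching the ordering used in the diff
    filtered.sort(key=lambda m: m.get("wasserstein", 0), reverse=True)
    top = filtered[:top_k]
    # Ranks are reassigned relative to the filtered subset
    for i, m in enumerate(top, start=1):
        m["rank"] = i
    return top


# Made-up sample data for illustration only
matches = [
    {"rank": 1, "catalog_id": "a_001", "location": "skycrest_valley",
     "body_part": "head", "wasserstein": 0.9},
    {"rank": 2, "catalog_id": "b_001", "location": "silvershadow_highlands",
     "body_part": "head", "wasserstein": 0.7},
    {"rank": 3, "catalog_id": "c_001", "location": "skycrest_valley",
     "body_part": "tail", "wasserstein": 0.5},
]

result = filter_and_rerank(matches, locations=["skycrest_valley"], top_k=2)
reranked = [(m["rank"], m["catalog_id"]) for m in result]
```

Because ranks change after filtering, a rank stored in a cached file name can go stale. That may be why `generate_visualizations_from_npz` looks up NPZ files by `catalog_id` rather than by rank, though the commit does not say so.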
src/snowleopard_reid/catalog/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ """Catalog module for snow leopard re-identification.
2
+
3
+ This module provides utilities for loading and managing the snow leopard catalog,
4
+ including individual metadata and feature data.
5
+ """
6
+
7
+ from .loader import (
8
+ get_all_catalog_features,
9
+ get_available_body_parts,
10
+ get_available_locations,
11
+ get_catalog_metadata_for_id,
12
+ get_filtered_catalog_features,
13
+ load_catalog_index,
14
+ load_leopard_metadata,
15
+ )
16
+
17
+ __all__ = [
18
+ "load_catalog_index",
19
+ "load_leopard_metadata",
20
+ "get_all_catalog_features",
21
+ "get_filtered_catalog_features",
22
+ "get_available_locations",
23
+ "get_available_body_parts",
24
+ "get_catalog_metadata_for_id",
25
+ ]
src/snowleopard_reid/catalog/loader.py ADDED
@@ -0,0 +1,379 @@
1
+ """Utilities for loading and managing the snow leopard catalog.
2
+
3
+ This module provides functions for loading catalog metadata, individual leopard
4
+ information, and catalog features for matching operations.
5
+ """
6
+
7
+ from pathlib import Path
8
+
9
+ import torch
10
+ import yaml
11
+
12
+
13
+ def load_catalog_index(catalog_root: Path) -> dict:
14
+ """Load the catalog index YAML file.
15
+
16
+ Args:
17
+ catalog_root: Path to catalog root directory (e.g., data/08_catalog/v1.0/)
18
+
19
+ Returns:
20
+ Dictionary with catalog index data including:
21
+ - catalog_version: str
22
+ - feature_extractors: dict
23
+ - individuals: list
24
+ - statistics: dict
25
+
26
+ Raises:
27
+ FileNotFoundError: If catalog index file doesn't exist
28
+ yaml.YAMLError: If YAML parsing fails
29
+ """
30
+ index_path = catalog_root / "catalog_index.yaml"
31
+
32
+ if not index_path.exists():
33
+ raise FileNotFoundError(f"Catalog index not found: {index_path}")
34
+
35
+ try:
36
+ with open(index_path) as f:
37
+ index = yaml.safe_load(f)
38
+ return index
39
+ except yaml.YAMLError as e:
40
+ raise yaml.YAMLError(f"Failed to parse catalog index: {e}") from e
41
+
42
+
43
+ def load_leopard_metadata(metadata_path: Path) -> dict:
44
+ """Load metadata YAML file for a specific leopard.
45
+
46
+ Args:
47
+ metadata_path: Path to leopard metadata.yaml file
48
+
49
+ Returns:
50
+ Dictionary with leopard metadata including:
51
+ - individual_id: str
52
+ - leopard_name: str
53
+ - reference_images: list
54
+ - statistics: dict
55
+
56
+ Raises:
57
+ FileNotFoundError: If metadata file doesn't exist
58
+ yaml.YAMLError: If YAML parsing fails
59
+ """
60
+ if not metadata_path.exists():
61
+ raise FileNotFoundError(f"Leopard metadata not found: {metadata_path}")
62
+
63
+ try:
64
+ with open(metadata_path) as f:
65
+ metadata = yaml.safe_load(f)
66
+ return metadata
67
+ except yaml.YAMLError as e:
68
+ raise yaml.YAMLError(f"Failed to parse leopard metadata: {e}") from e
69
+
70
+
71
+ def get_all_catalog_features(
72
+ catalog_root: Path,
73
+ extractor: str = "sift",
74
+ ) -> dict[str, dict[str, torch.Tensor]]:
75
+ """Load all catalog features for a specific extractor.
76
+
77
+ Args:
78
+ catalog_root: Path to catalog root directory (e.g., data/08_catalog/v1.0/)
79
+ extractor: Feature extractor name (default: 'sift')
80
+
81
+ Returns:
82
+ Dictionary mapping catalog_id to feature dict:
83
+ {
84
+ "leopard1_2022_001": {
85
+ "keypoints": torch.Tensor,
86
+ "descriptors": torch.Tensor,
87
+ "scores": torch.Tensor,
88
+ ...
89
+ },
90
+ ...
91
+ }
92
+
93
+ Raises:
94
+ FileNotFoundError: If catalog doesn't exist
95
+ ValueError: If no features found for extractor
96
+ """
97
+ if not catalog_root.exists():
98
+ raise FileNotFoundError(f"Catalog root not found: {catalog_root}")
99
+
100
+ # Load catalog index to get all individuals
101
+ index = load_catalog_index(catalog_root)
102
+
103
+ # Check if extractor is available
104
+ available_extractors = index.get("feature_extractors", {})
105
+ if extractor not in available_extractors:
106
+ raise ValueError(
107
+ f"Extractor '{extractor}' not available in catalog. "
108
+ f"Available: {list(available_extractors.keys())}"
109
+ )
110
+
111
+ catalog_features = {}
112
+ database_dir = catalog_root / "database"
113
+
114
+ # Load features for each individual
115
+ for individual in index["individuals"]:
116
+ # Support both 'leopard_name' and 'individual_name' keys
117
+ leopard_name = individual.get("leopard_name") or individual.get(
118
+ "individual_name"
119
+ )
120
+ location = individual.get("location", "")
121
+
122
+ # Construct path: database/{location}/{individual_name}/
123
+ if location:
124
+ leopard_dir = database_dir / location / leopard_name
125
+ else:
126
+ leopard_dir = database_dir / leopard_name
127
+
128
+ # Load leopard metadata to get all reference images
129
+ metadata_path = leopard_dir / "metadata.yaml"
130
+ metadata = load_leopard_metadata(metadata_path)
131
+
132
+ # Load features for each reference image
133
+ for ref_image in metadata["reference_images"]:
134
+ # Check if features exist for this extractor
135
+ if extractor not in ref_image.get("features", {}):
136
+ continue
137
+
138
+ # Get feature path (relative to database directory in metadata)
139
+ feature_rel_path = ref_image["features"][extractor]
140
+ feature_path = database_dir / feature_rel_path
141
+
142
+ if not feature_path.exists():
143
+ # Skip missing feature files (no warning is logged here)
144
+ continue
145
+
146
+ # Create catalog ID: leopard_name_year_imagenum
147
+ # e.g., "naguima_2022_001"
148
+ image_id = ref_image["image_id"]
149
+ catalog_id = f"{leopard_name.lower().replace(' ', '_')}_{image_id}"
150
+
151
+ # Load features
152
+ try:
153
+ feats = torch.load(feature_path, map_location="cpu", weights_only=False)
154
+ catalog_features[catalog_id] = feats
155
+ except Exception:
156
+ # Skip files that can't be loaded
157
+ continue
158
+
159
+ if not catalog_features:
160
+ raise ValueError(f"No features found for extractor '{extractor}' in catalog")
161
+
162
+ return catalog_features
163
+
164
+
165
+ def get_filtered_catalog_features(
166
+ catalog_root: Path,
167
+ extractor: str = "sift",
168
+ locations: list[str] | None = None,
169
+ body_parts: list[str] | None = None,
170
+ ) -> dict[str, dict[str, torch.Tensor]]:
171
+ """Load filtered catalog features for a specific extractor.
172
+
173
+ Args:
174
+ catalog_root: Path to catalog root directory (e.g., data/08_catalog/v1.0/)
175
+ extractor: Feature extractor name (default: 'sift')
176
+ locations: List of locations to filter by (e.g., ["skycrest_valley", "silvershadow_highlands"]).
177
+ If None, includes all locations.
178
+ body_parts: List of body parts to filter by (e.g., ["head", "right_flank"]).
179
+ If None, includes all body parts.
180
+
181
+ Returns:
182
+ Dictionary mapping catalog_id to feature dict:
183
+ {
184
+ "leopard1_2022_001": {
185
+ "keypoints": torch.Tensor,
186
+ "descriptors": torch.Tensor,
187
+ "scores": torch.Tensor,
188
+ ...
189
+ },
190
+ ...
191
+ }
192
+
193
+ Raises:
194
+ FileNotFoundError: If catalog doesn't exist
195
+ ValueError: If no features found for extractor or filters
196
+ """
197
+ if not catalog_root.exists():
198
+ raise FileNotFoundError(f"Catalog root not found: {catalog_root}")
199
+
200
+ # Load catalog index to get all individuals
201
+ index = load_catalog_index(catalog_root)
202
+
203
+ # Check if extractor is available
204
+ available_extractors = index.get("feature_extractors", {})
205
+ if extractor not in available_extractors:
206
+ raise ValueError(
207
+ f"Extractor '{extractor}' not available in catalog. "
208
+ f"Available: {list(available_extractors.keys())}"
209
+ )
210
+
211
+ catalog_features = {}
212
+ database_dir = catalog_root / "database"
213
+
214
+ # Load features for each individual
215
+ for individual in index["individuals"]:
216
+ # Support both 'leopard_name' and 'individual_name' keys
217
+ leopard_name = individual.get("leopard_name") or individual.get(
218
+ "individual_name"
219
+ )
220
+ location = individual.get("location", "")
221
+
222
+ # Filter by location if specified
223
+ if locations is not None and location not in locations:
224
+ continue
225
+
226
+ # Construct path: database/{location}/{individual_name}/
227
+ if location:
228
+ leopard_dir = database_dir / location / leopard_name
229
+ else:
230
+ leopard_dir = database_dir / leopard_name
231
+
232
+ # Load leopard metadata to get all reference images
233
+ metadata_path = leopard_dir / "metadata.yaml"
234
+ metadata = load_leopard_metadata(metadata_path)
235
+
236
+ # Load features for each reference image
237
+ for ref_image in metadata["reference_images"]:
238
+ # Filter by body part if specified
239
+ if body_parts is not None:
240
+ ref_body_part = ref_image.get("body_part", "")
241
+ if ref_body_part not in body_parts:
242
+ continue
243
+
244
+ # Check if features exist for this extractor
245
+ if extractor not in ref_image.get("features", {}):
246
+ continue
247
+
248
+ # Get feature path (relative to database directory in metadata)
249
+ feature_rel_path = ref_image["features"][extractor]
250
+ feature_path = database_dir / feature_rel_path
251
+
252
+ if not feature_path.exists():
253
+ # Skip missing feature files (no warning is logged here)
254
+ continue
255
+
256
+ # Create catalog ID: leopard_name_year_imagenum
257
+ # e.g., "naguima_2022_001"
258
+ image_id = ref_image["image_id"]
259
+ catalog_id = f"{leopard_name.lower().replace(' ', '_')}_{image_id}"
260
+
261
+ # Load features
262
+ try:
263
+ feats = torch.load(feature_path, map_location="cpu", weights_only=False)
264
+ catalog_features[catalog_id] = feats
265
+ except Exception:
266
+ # Skip files that can't be loaded
267
+ continue
268
+
269
+ if not catalog_features:
270
+ filter_info = []
271
+ if locations:
272
+ filter_info.append(f"locations={locations}")
273
+ if body_parts:
274
+ filter_info.append(f"body_parts={body_parts}")
275
+ filter_str = ", ".join(filter_info) if filter_info else "no filters"
276
+ raise ValueError(
277
+ f"No features found for extractor '{extractor}' with {filter_str}"
278
+ )
279
+
280
+ return catalog_features
281
+
282
+
283
+ def get_available_locations(catalog_root: Path) -> list[str]:
284
+ """Get list of available locations from catalog.
285
+
286
+ Args:
287
+ catalog_root: Path to catalog root directory
288
+
289
+ Returns:
290
+ List of location names prepended with "all" (e.g., ["all", "skycrest_valley", "silvershadow_highlands"])
291
+ """
292
+ try:
293
+ index = load_catalog_index(catalog_root)
294
+ locations = index.get("statistics", {}).get("locations", [])
295
+ return ["all"] + sorted(locations)
296
+ except Exception:
297
+ return ["all"]
298
+
299
+
300
+ def get_available_body_parts(catalog_root: Path) -> list[str]:
301
+ """Get list of available body parts from catalog.
302
+
303
+ Args:
304
+ catalog_root: Path to catalog root directory
305
+
306
+ Returns:
307
+ List of body part names prepended with "all"
308
+ (e.g., ["all", "head", "left_flank", "right_flank", "tail", "misc"])
309
+ """
310
+ try:
311
+ index = load_catalog_index(catalog_root)
312
+ body_parts = index.get("statistics", {}).get("body_parts", [])
313
+ return ["all"] + sorted(body_parts)
314
+ except Exception:
315
+ return ["all"]
316
+
317
+
318
+ def get_catalog_metadata_for_id(
319
+ catalog_root: Path,
320
+ catalog_id: str,
321
+ ) -> dict | None:
322
+ """Get full metadata for a specific catalog ID.
323
+
324
+ Args:
325
+ catalog_root: Path to catalog root directory
326
+ catalog_id: Catalog ID (e.g., "naguima_2022_001")
327
+
328
+ Returns:
329
+ Dictionary with metadata including:
330
+ - leopard_name: str
331
+ - filename: str
332
+ - image_path: Path
333
+ - individual_id: str
334
+ Or None if not found
335
+
336
+ Raises:
337
+ FileNotFoundError: If catalog doesn't exist
338
+ """
339
+ if not catalog_root.exists():
340
+ raise FileNotFoundError(f"Catalog root not found: {catalog_root}")
341
+
342
+ # Load catalog index
343
+ index = load_catalog_index(catalog_root)
344
+ database_dir = catalog_root / "database"
345
+
346
+ # Try to find matching individual
347
+ for individual in index["individuals"]:
348
+ # Support both 'leopard_name' and 'individual_name' keys
349
+ leopard_name = individual.get("leopard_name") or individual.get(
350
+ "individual_name"
351
+ )
352
+ location = individual.get("location", "")
353
+
354
+ # Construct path: database/{location}/{individual_name}/
355
+ if location:
356
+ leopard_dir = database_dir / location / leopard_name
357
+ else:
358
+ leopard_dir = database_dir / leopard_name
359
+
360
+ # Load leopard metadata
361
+ metadata_path = leopard_dir / "metadata.yaml"
362
+ metadata = load_leopard_metadata(metadata_path)
363
+
364
+ # Check each reference image
365
+ for ref_image in metadata["reference_images"]:
366
+ # Construct expected catalog ID
367
+ image_id = ref_image["image_id"]
368
+ expected_id = f"{leopard_name.lower().replace(' ', '_')}_{image_id}"
369
+
370
+ if expected_id == catalog_id:
371
+ # Found match
372
+ return {
373
+ "leopard_name": leopard_name,
374
+ "image_path": database_dir / ref_image["path"],
375
+ "individual_id": metadata["individual_id"],
376
+ "filename": ref_image["filename"],
377
+ }
378
+
379
+ return None
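Both feature loaders and `get_catalog_metadata_for_id` in the file above build catalog IDs with the same f-string: lower-case the leopard name, replace spaces with underscores, then append the image ID. The sketch below isolates that rule and shows how a one-time index could replace the per-lookup scan over every `metadata.yaml`. It is an illustration under assumptions, not the commit's code: `make_catalog_id`, `build_id_index`, and the sample names, image IDs, and filenames are all hypothetical.

```python
def make_catalog_id(leopard_name: str, image_id: str) -> str:
    """Hypothetical helper mirroring the f-string used in loader.py."""
    return f"{leopard_name.lower().replace(' ', '_')}_{image_id}"


def build_id_index(individuals: list[dict]) -> dict[str, dict]:
    """Map catalog_id -> metadata once, instead of scanning on every lookup."""
    index = {}
    for individual in individuals:
        name = individual["leopard_name"]
        for ref in individual["reference_images"]:
            index[make_catalog_id(name, ref["image_id"])] = {
                "leopard_name": name,
                "filename": ref["filename"],
            }
    return index


# Illustrative in-memory data; the real values live in metadata.yaml files
individuals = [
    {"leopard_name": "Naguima",
     "reference_images": [{"image_id": "2022_001", "filename": "IMG_0001.jpg"}]},
    {"leopard_name": "Snow Queen",
     "reference_images": [{"image_id": "2023_014", "filename": "IMG_0420.jpg"}]},
]

id_index = build_id_index(individuals)
found = id_index.get("snow_queen_2023_014")
missing = id_index.get("unknown_0000_000")
```

Keeping the ID rule in one helper would stop the three call sites from drifting apart. Whether the lookup cost matters depends on catalog size, which the diff does not state.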
src/snowleopard_reid/data_setup.py ADDED
@@ -0,0 +1,102 @@
1
+ """Data setup utilities for extracting archives on first run.
2
+
3
+ This module handles the extraction of catalog and cache archives when the
4
+ application starts. Archives are extracted only once, on first run.
5
+ """
6
+
7
+ import logging
8
+ import tarfile
9
+ from pathlib import Path
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ # Paths relative to project root
14
+ PROJECT_ROOT = Path(__file__).parent.parent.parent
15
+ DATA_DIR = PROJECT_ROOT / "data"
16
+ CATALOG_ARCHIVE = DATA_DIR / "catalog.tar.gz"
17
+ CATALOG_DIR = DATA_DIR / "catalog"
18
+ CACHE_ARCHIVE = DATA_DIR / "cache.tar.gz"
19
+ CACHE_DIR = PROJECT_ROOT / "cached_results"
20
+
21
+
22
+ def extract_archive(archive_path: Path, extract_to: Path) -> bool:
23
+ """Extract a tar.gz archive to specified directory.
24
+
25
+ Args:
26
+ archive_path: Path to the .tar.gz archive
27
+ extract_to: Directory to extract to (parent of archived dir)
28
+
29
+ Returns:
30
+ True if extraction successful, False otherwise
31
+ """
32
+ if not archive_path.exists():
33
+ logger.debug(f"Archive not found: {archive_path}")
34
+ return False
35
+
36
+ try:
37
+ logger.info(f"Extracting {archive_path.name}...")
38
+ extract_to.mkdir(parents=True, exist_ok=True)
39
+
40
+ with tarfile.open(archive_path, "r:gz") as tar:
41
+ tar.extractall(path=extract_to)
42
+
43
+ logger.info(f"Extracted {archive_path.name} successfully")
44
+ return True
45
+
46
+ except Exception as e:
47
+ logger.error(f"Failed to extract {archive_path.name}: {e}")
48
+ return False
49
+
50
+
51
+ def ensure_data_extracted() -> None:
52
+ """Ensure catalog and cache archives are extracted.
53
+
54
+ Call this function at application startup. It will:
55
+ - Check if catalog directory exists, extract from archive if not
56
+ - Check if cache directory exists, extract from archive if not
57
+
58
+ Archives are only extracted if:
59
+ 1. The archive file exists
60
+ 2. The target directory does not exist
61
+
62
+ This makes the function idempotent - safe to call multiple times.
63
+ """
64
+ # Extract catalog if needed
65
+ if CATALOG_ARCHIVE.exists() and not CATALOG_DIR.exists():
66
+ logger.info("First run detected - extracting catalog data...")
67
+ extract_archive(CATALOG_ARCHIVE, DATA_DIR)
68
+
69
+ # Extract cache if needed
70
+ if CACHE_ARCHIVE.exists() and not CACHE_DIR.exists():
71
+ logger.info("First run detected - extracting cached results...")
72
+ extract_archive(CACHE_ARCHIVE, PROJECT_ROOT)
73
+
74
+ # Log status
75
+ catalog_ready = CATALOG_DIR.exists()
76
+ cache_ready = CACHE_DIR.exists()
77
+
78
+ if catalog_ready and cache_ready:
79
+ logger.debug("All data directories ready")
80
+ else:
81
+ if not catalog_ready:
82
+ logger.warning(f"Catalog not available: {CATALOG_DIR}")
83
+ if not cache_ready:
84
+ logger.warning(f"Cache not available: {CACHE_DIR}")
85
+
86
+
87
+ def is_data_ready() -> bool:
88
+ """Check if the catalog directory exists.
89
+
90
+ Returns:
91
+ True if catalog directory exists, False otherwise
92
+ """
93
+ return CATALOG_DIR.exists()
94
+
95
+
96
+ def is_cache_ready() -> bool:
97
+ """Check if cache directory exists.
98
+
99
+ Returns:
100
+ True if cache directory exists, False otherwise
101
+ """
102
+ return CACHE_DIR.exists()
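The extract-once pattern in `ensure_data_extracted` (extract only when the archive exists and the target directory does not) can be checked end to end in a temporary directory. The sketch below is a hedged illustration, not the commit's code: `extract_once` is a hypothetical helper, and the `filter="data"` argument is an addition that the commit does not use. It is applied only when the running Python exposes `tarfile.data_filter`.

```python
import tarfile
import tempfile
from pathlib import Path


def extract_once(archive: Path, target_dir: Path, extract_to: Path) -> bool:
    """Extract archive into extract_to unless target_dir already exists.

    Hypothetical helper; returns True only when an extraction happened.
    """
    if not archive.exists() or target_dir.exists():
        return False
    extract_to.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Extraction filters reject absolute paths and ".." members
            tar.extractall(path=extract_to, filter="data")
        else:
            # Older Pythons without extraction filters
            tar.extractall(path=extract_to)
    return True


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    # Build a tiny archive whose top-level directory is "catalog"
    src = root / "src" / "catalog"
    src.mkdir(parents=True)
    (src / "catalog_index.yaml").write_text("catalog_version: demo\n")
    archive = root / "catalog.tar.gz"
    with tarfile.open(archive, "w:gz") as tar:
        tar.add(src, arcname="catalog")

    data_dir = root / "data"
    first = extract_once(archive, data_dir / "catalog", data_dir)
    second = extract_once(archive, data_dir / "catalog", data_dir)
    extracted_ok = (data_dir / "catalog" / "catalog_index.yaml").exists()
```

The first call extracts and the second is a no-op, which is the idempotence the docstring describes. One caveat applies to the sketch and the original alike: a partially extracted directory still counts as "exists", so an interrupted first run would not be retried.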
src/snowleopard_reid/features/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ """Features module for snow leopard re-identification.
2
+
3
+ This module provides utilities for extracting, loading, and saving features
4
+ from snow leopard images using various feature extractors (SIFT, SuperPoint, DISK, ALIKED).
5
+ """
6
+
7
+ from .extraction import (
8
+ extract_aliked_features,
9
+ extract_disk_features,
10
+ extract_features,
11
+ extract_sift_features,
12
+ extract_superpoint_features,
13
+ get_num_keypoints,
14
+ load_features,
15
+ save_features,
16
+ )
17
+
18
+ __all__ = [
19
+ "extract_features",
20
+ "extract_sift_features",
21
+ "extract_superpoint_features",
22
+ "extract_disk_features",
23
+ "extract_aliked_features",
24
+ "load_features",
25
+ "save_features",
26
+ "get_num_keypoints",
27
+ ]
src/snowleopard_reid/features/extraction.py ADDED
@@ -0,0 +1,388 @@
1
+ """Utilities for feature extraction and management.
2
+
3
+ This module provides functions for extracting, loading, and saving image features
4
+ using various feature extractors (SIFT, SuperPoint, DISK, ALIKED).
5
+ """
6
+
7
+ from pathlib import Path
8
+
9
+ import torch
10
+ from lightglue import ALIKED, DISK, SIFT, SuperPoint
11
+ from lightglue.utils import load_image, rbd
12
+ from PIL import Image
13
+
14
+
15
+ def extract_sift_features(
16
+ image_path: Path | str | Image.Image,
17
+ max_num_keypoints: int = 2048,
18
+ device: str = "cpu",
19
+ ) -> dict[str, torch.Tensor]:
20
+ """Extract SIFT features from an image.
21
+
22
+ Args:
23
+ image_path: Path to the image file (PIL Image input is not yet supported and raises TypeError)
24
+ max_num_keypoints: Maximum number of keypoints to extract (default: 2048, range: 512-4096)
25
+ device: Device to run extraction on ('cpu' or 'cuda')
26
+
27
+ Returns:
28
+ Dictionary with keys:
29
+ - keypoints: Tensor of shape [N, 2] with (x, y) coordinates
30
+ - descriptors: Tensor of shape [N, 128] with SIFT descriptors
31
+ - scores: Tensor of shape [N] with keypoint scores
32
+ - image_size: Tensor of shape [2] with (width, height)
33
+
34
+ Raises:
35
+ ValueError: If max_num_keypoints is out of valid range
36
+ FileNotFoundError: If image_path is a string/Path and file doesn't exist
37
+ """
38
+ # Validate parameters
39
+ if not (512 <= max_num_keypoints <= 4096):
40
+ raise ValueError(
41
+ f"max_num_keypoints must be in range 512-4096, got {max_num_keypoints}"
42
+ )
43
+
44
+ # Initialize extractor
45
+ extractor = SIFT(max_num_keypoints=max_num_keypoints).eval().to(device)
46
+
47
+ # Load image
48
+ if isinstance(image_path, (str, Path)):
49
+ image_path = Path(image_path)
50
+ if not image_path.exists():
51
+ raise FileNotFoundError(f"Image not found: {image_path}")
52
+ # load_image returns torch.Tensor [3, H, W]
53
+ image = load_image(str(image_path))
54
+ elif isinstance(image_path, Image.Image):
55
+ # PIL input is not converted yet; lightglue's load_image expects a file path,
56
+ # so only path input is accepted for now
57
+ raise TypeError(
58
+ "PIL Image input not yet supported, please provide path to image file"
59
+ )
60
+ else:
61
+ raise TypeError(
62
+ f"image_path must be str, Path, or PIL Image, got {type(image_path)}"
63
+ )
64
+
65
+ # Move image to device
66
+ image = image.to(device)
67
+
68
+ # Extract features
69
+ with torch.no_grad():
70
+ feats = extractor.extract(image) # auto-resizes image
71
+ feats = rbd(feats) # remove batch dimension
72
+
73
+ # Move features back to CPU for storage
74
+ if device != "cpu":
75
+ feats = {
76
+ k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in feats.items()
77
+ }
78
+
79
+ return feats
80
+
81
+
82
+ def extract_superpoint_features(
83
+ image_path: Path | str | Image.Image,
84
+ max_num_keypoints: int = 2048,
85
+ device: str = "cpu",
86
+ ) -> dict[str, torch.Tensor]:
87
+ """Extract SuperPoint features from an image.
88
+
89
+ Args:
90
+ image_path: Path to image file (PIL Image input is not yet supported and raises TypeError)
91
+ max_num_keypoints: Maximum number of keypoints to extract (default: 2048, range: 512-4096)
92
+ device: Device to run extraction on ('cpu' or 'cuda')
93
+
94
+ Returns:
95
+ Dictionary with keys:
96
+ - keypoints: Tensor of shape [N, 2] with (x, y) coordinates
97
+ - descriptors: Tensor of shape [N, 256] with SuperPoint descriptors
98
+ - scores: Tensor of shape [N] with keypoint scores
99
+ - image_size: Tensor of shape [2] with (width, height)
100
+
101
+ Raises:
102
+ ValueError: If max_num_keypoints is out of valid range
103
+ FileNotFoundError: If image_path is a string/Path and file doesn't exist
104
+ """
105
+ # Validate parameters
106
+ if not (512 <= max_num_keypoints <= 4096):
107
+ raise ValueError(
108
+ f"max_num_keypoints must be in range 512-4096, got {max_num_keypoints}"
109
+ )
110
+
111
+ # Initialize extractor
112
+ extractor = SuperPoint(max_num_keypoints=max_num_keypoints).eval().to(device)
113
+
114
+ # Load image
115
+ if isinstance(image_path, (str, Path)):
116
+ image_path = Path(image_path)
117
+ if not image_path.exists():
118
+ raise FileNotFoundError(f"Image not found: {image_path}")
119
+ image = load_image(str(image_path))
120
+ elif isinstance(image_path, Image.Image):
121
+ raise TypeError(
122
+ "PIL Image input not yet supported, please provide path to image file"
123
+ )
124
+ else:
125
+ raise TypeError(
126
+ f"image_path must be str, Path, or PIL Image, got {type(image_path)}"
127
+ )
128
+
129
+ # Move image to device
130
+ image = image.to(device)
131
+
132
+ # Extract features
133
+ with torch.no_grad():
134
+ feats = extractor.extract(image)
135
+ feats = rbd(feats) # remove batch dimension
136
+
137
+ # Move features back to CPU for storage
138
+ if device != "cpu":
139
+ feats = {
140
+ k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in feats.items()
141
+ }
142
+
143
+ return feats
144
+
145
+
146
+ def extract_disk_features(
147
+ image_path: Path | str | Image.Image,
148
+ max_num_keypoints: int = 2048,
149
+ device: str = "cpu",
150
+ ) -> dict[str, torch.Tensor]:
151
+ """Extract DISK features from an image.
152
+
153
+ Args:
154
+ image_path: Path to image file (PIL Image input is not yet supported and raises TypeError)
155
+ max_num_keypoints: Maximum number of keypoints to extract (default: 2048, range: 512-4096)
156
+ device: Device to run extraction on ('cpu' or 'cuda')
157
+
158
+ Returns:
159
+ Dictionary with keys:
160
+ - keypoints: Tensor of shape [N, 2] with (x, y) coordinates
161
+ - descriptors: Tensor of shape [N, 128] with DISK descriptors
162
+ - scores: Tensor of shape [N] with keypoint scores
163
+ - image_size: Tensor of shape [2] with (width, height)
164
+
165
+ Raises:
166
+ ValueError: If max_num_keypoints is out of valid range
167
+ FileNotFoundError: If image_path is a string/Path and file doesn't exist
168
+ """
169
+ # Validate parameters
170
+ if not (512 <= max_num_keypoints <= 4096):
171
+ raise ValueError(
172
+ f"max_num_keypoints must be in range 512-4096, got {max_num_keypoints}"
173
+ )
174
+
175
+ # Initialize extractor
176
+ extractor = DISK(max_num_keypoints=max_num_keypoints).eval().to(device)
177
+
178
+ # Load image
179
+ if isinstance(image_path, (str, Path)):
180
+ image_path = Path(image_path)
181
+ if not image_path.exists():
182
+ raise FileNotFoundError(f"Image not found: {image_path}")
183
+ image = load_image(str(image_path))
184
+ elif isinstance(image_path, Image.Image):
185
+ raise TypeError(
186
+ "PIL Image input not yet supported, please provide path to image file"
187
+ )
188
+ else:
189
+ raise TypeError(
190
+ f"image_path must be str, Path, or PIL Image, got {type(image_path)}"
191
+ )
192
+
193
+ # Move image to device
194
+ image = image.to(device)
195
+
196
+ # Extract features
197
+ with torch.no_grad():
198
+ feats = extractor.extract(image)
199
+ feats = rbd(feats) # remove batch dimension
200
+
201
+ # Move features back to CPU for storage
202
+ if device != "cpu":
203
+ feats = {
204
+ k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in feats.items()
205
+ }
206
+
207
+ return feats
208
+
209
+
210
+ def extract_aliked_features(
211
+ image_path: Path | str | Image.Image,
212
+ max_num_keypoints: int = 2048,
213
+ device: str = "cpu",
214
+ ) -> dict[str, torch.Tensor]:
215
+ """Extract ALIKED features from an image.
216
+
217
+ Args:
218
+ image_path: Path to image file (PIL Image input is not yet supported and raises TypeError)
219
+ max_num_keypoints: Maximum number of keypoints to extract (default: 2048, range: 512-4096)
220
+ device: Device to run extraction on ('cpu' or 'cuda')
221
+
222
+ Returns:
223
+ Dictionary with keys:
224
+ - keypoints: Tensor of shape [N, 2] with (x, y) coordinates
225
+ - descriptors: Tensor of shape [N, 128] with ALIKED descriptors
226
+ - scores: Tensor of shape [N] with keypoint scores
227
+ - image_size: Tensor of shape [2] with (width, height)
228
+
229
+ Raises:
230
+ ValueError: If max_num_keypoints is out of valid range
231
+ FileNotFoundError: If image_path is a string/Path and file doesn't exist
232
+ """
233
+ # Validate parameters
234
+ if not (512 <= max_num_keypoints <= 4096):
235
+ raise ValueError(
236
+ f"max_num_keypoints must be in range 512-4096, got {max_num_keypoints}"
237
+ )
238
+
239
+ # Initialize extractor
240
+ extractor = ALIKED(max_num_keypoints=max_num_keypoints).eval().to(device)
241
+
242
+ # Load image
243
+ if isinstance(image_path, (str, Path)):
244
+ image_path = Path(image_path)
245
+ if not image_path.exists():
246
+ raise FileNotFoundError(f"Image not found: {image_path}")
247
+ image = load_image(str(image_path))
248
+ elif isinstance(image_path, Image.Image):
249
+ raise TypeError(
250
+ "PIL Image input not yet supported, please provide path to image file"
251
+ )
252
+ else:
253
+ raise TypeError(
254
+ f"image_path must be str, Path, or PIL Image, got {type(image_path)}"
255
+ )
256
+
257
+ # Move image to device
258
+ image = image.to(device)
259
+
260
+ # Extract features
261
+ with torch.no_grad():
262
+ feats = extractor.extract(image)
263
+ feats = rbd(feats) # remove batch dimension
264
+
265
+ # Move features back to CPU for storage
266
+ if device != "cpu":
267
+ feats = {
268
+ k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in feats.items()
269
+ }
270
+
271
+ return feats
272
+
273
+
274
+ def load_features(features_path: Path | str) -> dict[str, torch.Tensor]:
275
+ """Load features from a PyTorch .pt file.
276
+
277
+ Args:
278
+ features_path: Path to .pt file containing features
279
+
280
+ Returns:
281
+ Dictionary with feature tensors (keypoints, descriptors, scores, etc.)
282
+
283
+ Raises:
284
+ FileNotFoundError: If features file doesn't exist
285
+ RuntimeError: If file cannot be loaded
286
+ """
287
+ features_path = Path(features_path)
288
+
289
+ if not features_path.exists():
290
+ raise FileNotFoundError(f"Features file not found: {features_path}")
291
+
292
+ try:
293
+ feats = torch.load(features_path, map_location="cpu", weights_only=False)
294
+ return feats
295
+ except Exception as e:
296
+ raise RuntimeError(f"Failed to load features from {features_path}: {e}") from e
297
+
298
+
299
+ def save_features(
300
+ features: dict[str, torch.Tensor],
301
+ output_path: Path | str,
302
+ create_dirs: bool = True,
303
+ ) -> None:
304
+ """Save features to a PyTorch .pt file.
305
+
306
+ Args:
307
+ features: Dictionary with feature tensors to save
308
+ output_path: Path where to save .pt file
309
+ create_dirs: Whether to create parent directories if they don't exist
310
+
311
+ Raises:
312
+ ValueError: If features dict is empty or invalid
313
+ OSError: If directory creation or file writing fails
314
+ """
315
+ if not features:
316
+ raise ValueError("Features dictionary is empty")
317
+
318
+ output_path = Path(output_path)
319
+
320
+ # Create parent directories if needed
321
+ if create_dirs and not output_path.parent.exists():
322
+ output_path.parent.mkdir(parents=True, exist_ok=True)
323
+
324
+ try:
325
+ torch.save(features, output_path)
326
+ except Exception as e:
327
+ raise OSError(f"Failed to save features to {output_path}: {e}") from e
328
+
329
+
330
+ def get_num_keypoints(features: dict[str, torch.Tensor]) -> int:
331
+ """Get the number of keypoints from a features dictionary.
332
+
333
+ Args:
334
+ features: Dictionary with 'keypoints' tensor
335
+
336
+ Returns:
337
+ Number of keypoints (first dimension of keypoints tensor)
338
+
339
+ Raises:
340
+ KeyError: If 'keypoints' key is missing
341
+ """
342
+ if "keypoints" not in features:
343
+ raise KeyError("Features dictionary missing 'keypoints' key")
344
+
345
+ return features["keypoints"].shape[0]
346
+
347
+
348
+ def extract_features(
349
+ extractor: str,
350
+ image_path: Path | str | Image.Image,
351
+ max_num_keypoints: int = 2048,
352
+ device: str = "cpu",
353
+ ) -> dict[str, torch.Tensor]:
354
+ """Extract features from an image using the specified extractor.
355
+
356
+ Factory function that dispatches to the appropriate feature extractor.
357
+
358
+ Args:
359
+ extractor: Feature extractor name ('sift', 'superpoint', 'disk', 'aliked')
360
+ image_path: Path to image file (PIL Image input is not yet supported and raises TypeError)
361
+ max_num_keypoints: Maximum number of keypoints to extract (default: 2048)
362
+ device: Device to run extraction on ('cpu' or 'cuda')
363
+
364
+ Returns:
365
+ Dictionary with feature tensors (keypoints, descriptors, scores, image_size)
366
+
367
+ Raises:
368
+ ValueError: If extractor name is not supported or max_num_keypoints is outside 512-4096
369
+ FileNotFoundError: If image_path is a string/Path and file doesn't exist
370
+
371
+ Examples:
372
+ >>> features = extract_features("sift", "image.jpg")
373
+ >>> features = extract_features("sift", "image.jpg", max_num_keypoints=4096, device="cuda")
374
+ """
375
+ extractor = extractor.lower()
376
+
377
+ if extractor == "sift":
378
+ return extract_sift_features(image_path, max_num_keypoints, device)
379
+ elif extractor == "superpoint":
380
+ return extract_superpoint_features(image_path, max_num_keypoints, device)
381
+ elif extractor == "disk":
382
+ return extract_disk_features(image_path, max_num_keypoints, device)
383
+ elif extractor == "aliked":
384
+ return extract_aliked_features(image_path, max_num_keypoints, device)
385
+ else:
386
+ raise ValueError(
387
+ f"Unsupported extractor: {extractor}. Supported extractors: sift, superpoint, disk, aliked"
388
+ )
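The if/elif chain in `extract_features` can also be written as a lookup table. The following is a minimal sketch, not the project's implementation: the stub extractors, the `EXTRACTORS` registry, and the `dispatch` function are hypothetical stand-ins so the example runs without lightglue or torch.

```python
from typing import Callable


def _stub(name: str) -> Callable[[str, int, str], dict]:
    """Build a hypothetical stand-in for one of the extract_*_features functions."""

    def extract(image_path: str, max_num_keypoints: int = 2048, device: str = "cpu") -> dict:
        # A real extractor would return tensors; the stub only echoes its arguments.
        return {
            "extractor": name,
            "image_path": image_path,
            "max_num_keypoints": max_num_keypoints,
            "device": device,
        }

    return extract


# Assumed registry name: maps lower-case extractor names to callables.
EXTRACTORS: dict[str, Callable[..., dict]] = {
    name: _stub(name) for name in ("sift", "superpoint", "disk", "aliked")
}


def dispatch(
    extractor: str,
    image_path: str,
    max_num_keypoints: int = 2048,
    device: str = "cpu",
) -> dict:
    """Look up the extractor by case-insensitive name and call it."""
    key = extractor.lower()
    if key not in EXTRACTORS:
        supported = ", ".join(EXTRACTORS)
        raise ValueError(
            f"Unsupported extractor: {extractor}. Supported extractors: {supported}"
        )
    return EXTRACTORS[key](image_path, max_num_keypoints, device)
```

With a table like this, adding a fifth extractor would mean adding one registry entry, and the error message stays in sync with the supported names.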
src/snowleopard_reid/images/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ """Image processing utilities for the Snow Leopard Re-ID project."""
2
+
3
+ from snowleopard_reid.images.processing import (
4
+ resize_image_if_needed,
5
+ resize_image_if_needed_cv2,
6
+ resize_image_if_needed_pil,
7
+ )
8
+
9
+ __all__ = [
10
+ "resize_image_if_needed",
11
+ "resize_image_if_needed_cv2",
12
+ "resize_image_if_needed_pil",
13
+ ]
src/snowleopard_reid/images/processing.py ADDED
@@ -0,0 +1,93 @@
1
+ """Image processing utilities for the Snow Leopard Re-ID project."""
2
+
3
+ import logging
4
+
5
+ import cv2
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ def resize_image_if_needed_pil(image: Image.Image, max_dim: int = 1024) -> Image.Image:
13
+ """
14
+ Resize PIL Image if either dimension exceeds max_dim, maintaining aspect ratio.
15
+
16
+ Args:
17
+ image: PIL Image to resize
18
+ max_dim: Maximum allowed dimension (default: 1024)
19
+
20
+ Returns:
21
+ Resized image (or original if no resize needed)
22
+ """
23
+ width, height = image.size
24
+
25
+ if width <= max_dim and height <= max_dim:
26
+ return image
27
+
28
+ # Calculate scaling factor
29
+ scale = min(max_dim / width, max_dim / height)
30
+
31
+ # Calculate new dimensions
32
+ new_width = int(width * scale)
33
+ new_height = int(height * scale)
34
+
35
+ # Resize image using high-quality LANCZOS filter
36
+ resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
37
+ logger.debug(f"Resized image from {width}x{height} to {new_width}x{new_height}")
38
+
39
+ return resized
40
+
41
+
42
+ def resize_image_if_needed_cv2(img: np.ndarray, max_dim: int = 1024) -> np.ndarray:
43
+ """
44
+ Resize cv2 image if either dimension exceeds max_dim, maintaining aspect ratio.
45
+
46
+ Args:
47
+ img: Input image as numpy array
48
+ max_dim: Maximum allowed dimension (default: 1024)
49
+
50
+ Returns:
51
+ Resized image (or original if no resize needed)
52
+ """
53
+ height, width = img.shape[:2]
54
+
55
+ if height <= max_dim and width <= max_dim:
56
+ return img
57
+
58
+ # Calculate scaling factor
59
+ scale = min(max_dim / height, max_dim / width)
60
+
61
+ # Calculate new dimensions
62
+ new_width = int(width * scale)
63
+ new_height = int(height * scale)
64
+
65
+ # Resize image using INTER_AREA (best for downscaling)
66
+ resized = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_AREA)
67
+ logger.debug(f"Resized image from {width}x{height} to {new_width}x{new_height}")
68
+
69
+ return resized
70
+
71
+
72
+ def resize_image_if_needed(
73
+ image: Image.Image | np.ndarray, max_dim: int = 1024
74
+ ) -> Image.Image | np.ndarray:
75
+ """
76
+ Resize image if either dimension exceeds max_dim, maintaining aspect ratio.
77
+
78
+ Automatically detects whether the input is a PIL Image or numpy array
79
+ and applies the appropriate resize function.
80
+
81
+ Args:
82
+ image: PIL Image or numpy array to resize
83
+ max_dim: Maximum allowed dimension (default: 1024)
84
+
85
+ Returns:
86
+ Resized image (or original if no resize needed)
87
+ """
88
+ if isinstance(image, Image.Image):
89
+ return resize_image_if_needed_pil(image=image, max_dim=max_dim)
90
+ elif isinstance(image, np.ndarray):
91
+ return resize_image_if_needed_cv2(img=image, max_dim=max_dim)
92
+ else:
93
+ raise TypeError(f"Expected PIL.Image.Image or np.ndarray, got {type(image)}")
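The scaling arithmetic shared by both resize helpers can be checked in isolation. This is a sketch under the assumption that only the size computation matters; `scaled_size` is a hypothetical helper that is not part of the module, and it involves neither PIL nor OpenCV.

```python
def scaled_size(width: int, height: int, max_dim: int = 1024) -> tuple[int, int]:
    """Return the (width, height) the resize helpers would produce."""
    # No resize when both sides already fit.
    if width <= max_dim and height <= max_dim:
        return width, height
    # The smaller ratio is the one that brings the longer side down to max_dim.
    scale = min(max_dim / width, max_dim / height)
    return int(width * scale), int(height * scale)


print(scaled_size(2048, 1024))  # (1024, 512)
print(scaled_size(800, 600))    # (800, 600)
```

Because `int()` truncates, floating-point rounding can in principle leave a side one pixel short of the exact value for sizes that do not divide evenly; using `round()` instead would be one way to avoid that.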
src/snowleopard_reid/masks/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ """Mask processing utilities for snow leopard re-identification."""
2
+
3
+ from snowleopard_reid.masks.processing import (
4
+ add_padding_to_bbox,
5
+ crop_and_mask_image,
6
+ get_mask_bbox,
7
+ )
8
+
9
+ __all__ = [
10
+ "get_mask_bbox",
11
+ "add_padding_to_bbox",
12
+ "crop_and_mask_image",
13
+ ]
src/snowleopard_reid/masks/processing.py ADDED
@@ -0,0 +1,99 @@
1
+ """Utilities for mask processing and image cropping.
2
+
3
+ This module provides functions for working with binary segmentation masks,
4
+ calculating bounding boxes, and cropping images with masks applied.
5
+ """
6
+
7
+ import numpy as np
8
+ from PIL import Image
9
+
10
+
11
+ def get_mask_bbox(mask: np.ndarray) -> tuple[int, int, int, int]:
12
+ """Calculate the tight bounding box of a binary mask.
13
+
14
+ Args:
15
+ mask: Binary mask array (0=background, 255=foreground)
16
+
17
+ Returns:
18
+ Tuple of (x_min, y_min, x_max, y_max) in pixel coordinates
19
+
20
+ Raises:
21
+ ValueError: If mask is empty (no foreground pixels)
22
+ """
23
+ # Find all pixels that are part of the mask
24
+ rows = np.any(mask > 0, axis=1)
25
+ cols = np.any(mask > 0, axis=0)
26
+
27
+ if not np.any(rows) or not np.any(cols):
28
+ raise ValueError("Mask is empty (no foreground pixels)")
29
+
30
+ y_min, y_max = np.where(rows)[0][[0, -1]]
31
+ x_min, x_max = np.where(cols)[0][[0, -1]]
32
+
33
+ return int(x_min), int(y_min), int(x_max), int(y_max)
34
+
35
+
36
+ def add_padding_to_bbox(
37
+ bbox: tuple[int, int, int, int],
38
+ padding: int,
39
+ image_width: int,
40
+ image_height: int,
41
+ ) -> tuple[int, int, int, int]:
42
+ """Add padding to a bounding box, clamped to image boundaries.
43
+
44
+ Args:
45
+ bbox: Original bounding box (x_min, y_min, x_max, y_max)
46
+ padding: Padding in pixels to add on all sides
47
+ image_width: Image width for clamping
48
+ image_height: Image height for clamping
49
+
50
+ Returns:
51
+ Padded bounding box (x_min, y_min, x_max, y_max)
52
+ """
53
+ x_min, y_min, x_max, y_max = bbox
54
+
55
+ x_min = max(0, x_min - padding)
56
+ y_min = max(0, y_min - padding)
57
+ x_max = min(image_width - 1, x_max + padding)
58
+ y_max = min(image_height - 1, y_max + padding)
59
+
60
+ return x_min, y_min, x_max, y_max
61
+
62
+
63
+ def crop_and_mask_image(
64
+ image: Image.Image,
65
+ mask: np.ndarray,
66
+ bbox: tuple[int, int, int, int],
67
+ ) -> Image.Image:
68
+ """Crop image to bbox and apply mask with black background.
69
+
70
+ Args:
71
+ image: Original PIL Image
72
+ mask: Binary mask array (same size as image, 0=background, 255=foreground)
73
+ bbox: Bounding box to crop to (x_min, y_min, x_max, y_max)
74
+
75
+ Returns:
76
+ Cropped and masked PIL Image with black background
77
+ """
78
+ x_min, y_min, x_max, y_max = bbox
79
+
80
+ # Crop image and mask to bounding box
81
+ cropped_image = image.crop((x_min, y_min, x_max + 1, y_max + 1))
82
+ cropped_mask = mask[y_min : y_max + 1, x_min : x_max + 1]
83
+
84
+ # Convert image to numpy array
85
+ image_array = np.array(cropped_image)
86
+
87
+ # Create mask with correct shape (add channel dimension if needed)
88
+ if len(image_array.shape) == 3:
89
+ # RGB image - expand mask to 3 channels
90
+ mask_3d = np.repeat(cropped_mask[:, :, np.newaxis] > 0, 3, axis=2)
91
+ else:
92
+ # Grayscale image
93
+ mask_3d = cropped_mask > 0
94
+
95
+ # Apply mask: keep original pixels where mask is True, black elsewhere
96
+ masked_array = np.where(mask_3d, image_array, 0)
97
+
98
+ # Convert back to PIL Image
99
+ return Image.fromarray(masked_array.astype(np.uint8))
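The bounding boxes in this module use inclusive maximum coordinates, which is why `crop_and_mask_image` adds 1 before cropping and slicing. The sketch below works through the padding and the inclusive-to-exclusive conversion with plain integers; `pad_bbox` mirrors `add_padding_to_bbox`, while `to_crop_box` is a hypothetical helper introduced only for illustration.

```python
Box = tuple[int, int, int, int]


def pad_bbox(bbox: Box, padding: int, image_width: int, image_height: int) -> Box:
    """Pad an inclusive (x_min, y_min, x_max, y_max) box, clamped to the image."""
    x_min, y_min, x_max, y_max = bbox
    return (
        max(0, x_min - padding),
        max(0, y_min - padding),
        min(image_width - 1, x_max + padding),   # last valid column index
        min(image_height - 1, y_max + padding),  # last valid row index
    )


def to_crop_box(bbox: Box) -> Box:
    """Convert an inclusive box to one with exclusive right and lower edges."""
    x_min, y_min, x_max, y_max = bbox
    return (x_min, y_min, x_max + 1, y_max + 1)


padded = pad_bbox((5, 20, 95, 60), padding=10, image_width=100, image_height=100)
print(padded)               # (0, 10, 99, 70)
print(to_crop_box(padded))  # (0, 10, 100, 71)
```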
src/snowleopard_reid/pipeline/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ """Pipeline module for snow leopard re-identification.
2
+
3
+ This module provides the complete end-to-end pipeline for identifying individual
4
+ snow leopards from query images.
5
+ """
6
+
7
+ from .stages import (
8
+ run_feature_extraction_stage,
9
+ run_mask_selection_stage,
10
+ run_matching_stage,
11
+ run_preprocess_stage,
12
+ run_segmentation_stage,
13
+ select_best_mask,
14
+ )
15
+
16
+ __all__ = [
17
+ "run_segmentation_stage",
18
+ "run_mask_selection_stage",
19
+ "run_preprocess_stage",
20
+ "run_feature_extraction_stage",
21
+ "run_matching_stage",
22
+ "select_best_mask",
23
+ ]
src/snowleopard_reid/pipeline/stages/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ """Pipeline stages for snow leopard re-identification.
2
+
3
+ This module contains all pipeline stages that process query images through
4
+ segmentation, feature extraction, and matching.
5
+ """
6
+
7
+ from .feature_extraction import run_feature_extraction_stage
8
+ from .mask_selection import run_mask_selection_stage, select_best_mask
9
+ from .matching import run_matching_stage
10
+ from .preprocess import run_preprocess_stage
11
+ from .segmentation import run_segmentation_stage
12
+
13
+ __all__ = [
14
+ "run_segmentation_stage",
15
+ "run_mask_selection_stage",
16
+ "select_best_mask",
17
+ "run_preprocess_stage",
18
+ "run_feature_extraction_stage",
19
+ "run_matching_stage",
20
+ ]
src/snowleopard_reid/pipeline/stages/feature_extraction.py ADDED
@@ -0,0 +1,134 @@
1
+ """Feature extraction stage for query images.
2
+
3
+ This module extracts features from cropped query images for matching
4
+ against the catalog.
5
+ """
6
+
7
+ import logging
8
+ import tempfile
9
+ from pathlib import Path
10
+
11
+ import torch
12
+ from PIL import Image
13
+
14
+ from snowleopard_reid import features, get_device
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ def run_feature_extraction_stage(
20
+ image: Image.Image | Path | str,
21
+ extractor: str = "sift",
22
+ max_keypoints: int = 2048,
23
+ device: str | None = None,
24
+ ) -> dict:
25
+ """Extract features from query image.
26
+
27
+ This stage extracts keypoints and descriptors from the preprocessed query
28
+ image using the specified feature extractor.
29
+
30
+ Args:
31
+ image: PIL Image object or path to image file
32
+ extractor: Feature extractor to use (default: 'sift')
33
+ max_keypoints: Maximum number of keypoints to extract (default: 2048)
34
+ device: Device to run on ('cpu', 'cuda', or None for auto-detect)
35
+
36
+ Returns:
37
+ Stage dict with structure:
38
+ {
39
+ "stage_id": "feature_extraction",
40
+ "stage_name": "Feature Extraction",
41
+ "description": "Extract keypoints and descriptors",
42
+ "config": {
43
+ "extractor": str,
44
+ "max_keypoints": int,
45
+ "device": str
46
+ },
47
+ "metrics": {
48
+ "num_keypoints": int
49
+ },
50
+ "data": {
51
+ "features": {
52
+ "keypoints": torch.Tensor [N, 2],
53
+ "descriptors": torch.Tensor [N, D],
54
+ "scores": torch.Tensor [N],
55
+ "image_size": torch.Tensor [2]
56
+ }
57
+ }
58
+ }
59
+
60
+ Raises:
61
+ ValueError: If extractor is not supported
62
+ FileNotFoundError: If image path doesn't exist
63
+ RuntimeError: If feature extraction fails
64
+ """
65
+ # Auto-detect device if not specified
66
+ device = get_device(device=device, verbose=True)
67
+
68
+ # Extract features using the factory function
69
+ features_dict = _extract_features_from_image(
70
+ image=image, extractor=extractor, max_keypoints=max_keypoints, device=device
71
+ )
72
+
73
+ # Get number of keypoints
74
+ num_kpts = features.get_num_keypoints(features_dict)
75
+ logger.info(f"Extracted {num_kpts} keypoints using {extractor.upper()}")
76
+
77
+ # Return standardized stage dict
78
+ return {
79
+ "stage_id": "feature_extraction",
80
+ "stage_name": "Feature Extraction",
81
+ "description": "Extract keypoints and descriptors",
82
+ "config": {
83
+ "extractor": extractor,
84
+ "max_keypoints": max_keypoints,
85
+ "device": device,
86
+ },
87
+ "metrics": {
88
+ "num_keypoints": num_kpts,
89
+ },
90
+ "data": {
91
+ "features": features_dict,
92
+ },
93
+ }
94
+
95
+
96
+ def _extract_features_from_image(
97
+ image: Image.Image | Path | str,
98
+ extractor: str,
99
+ max_keypoints: int,
100
+ device: str,
101
+ ) -> dict[str, torch.Tensor]:
102
+ """Extract features from PIL Image or path using specified extractor.
103
+
104
+ This is a wrapper that handles PIL Image input by saving to a temporary file,
105
+ since lightglue's load_image() requires a file path.
106
+
107
+ Args:
108
+ image: PIL Image or path to image
109
+ extractor: Feature extractor to use ('sift', 'superpoint', 'disk', 'aliked')
110
+ max_keypoints: Maximum keypoints to extract
111
+ device: Device to use
112
+
113
+ Returns:
114
+ Features dictionary
115
+ """
116
+ if isinstance(image, Image.Image):
117
+ # Save PIL Image to temporary file
118
+ with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
119
+ tmp_path = Path(tmp.name)
120
+ image.convert("RGB").save(tmp_path, quality=95)  # JPEG cannot store RGBA/P modes
121
+
122
+ try:
123
+ # Extract features from temporary file
124
+ feats = features.extract_features(
125
+ extractor, tmp_path, max_keypoints, device
126
+ )
127
+ finally:
128
+ # Clean up temporary file
129
+ tmp_path.unlink()
130
+
131
+ return feats
132
+ else:
133
+ # Image is already a path
134
+ return features.extract_features(extractor, image, max_keypoints, device)
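The temporary-file round trip used by `_extract_features_from_image` follows a general pattern: create the file with `delete=False`, close the handle, hand the path to a path-only function, and remove the file in a `finally` block. The sketch below shows that pattern with raw bytes instead of an image; `run_on_temp_file` and its parameters are hypothetical names, not part of the project.

```python
import tempfile
from pathlib import Path
from typing import Callable, TypeVar

T = TypeVar("T")


def run_on_temp_file(data: bytes, suffix: str, process: Callable[[Path], T]) -> T:
    """Write data to a temporary file, call process(path), then delete the file."""
    # delete=False keeps the file on disk after the handle is closed,
    # so a path-only consumer can reopen it by name.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(data)
        tmp_path = Path(tmp.name)

    try:
        return process(tmp_path)
    finally:
        # Runs whether process() returned or raised.
        tmp_path.unlink(missing_ok=True)


size = run_on_temp_file(b"abc", ".bin", lambda p: p.stat().st_size)
print(size)  # 3
```

Closing the handle before the consumer opens the path matters on platforms that do not allow a second open of a file that is still held open.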
src/snowleopard_reid/pipeline/stages/mask_selection.py ADDED
@@ -0,0 +1,153 @@
1
+ """Mask selection stage for choosing the best snow leopard mask.
2
+
3
+ This module provides logic for selecting the best mask from multiple YOLO predictions.
4
+ """
5
+
6
+ import logging
7
+
8
+ import numpy as np
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ def select_best_mask(
14
+ predictions: list[dict],
15
+ strategy: str = "confidence_area",
16
+ ) -> tuple[int, dict]:
17
+ """Select the best mask from predictions using specified strategy.
18
+
19
+ Args:
20
+ predictions: List of prediction dicts from YOLO segmentation stage
21
+ strategy: Selection strategy ('confidence_area', 'confidence', 'area', 'center')
22
+
23
+ Returns:
24
+ Tuple of (selected_index, selected_prediction)
25
+
26
+ Raises:
27
+ ValueError: If predictions list is empty or strategy is invalid
28
+ """
29
+ if not predictions:
30
+ raise ValueError("Predictions list is empty")
31
+
32
+ valid_strategies = ["confidence_area", "confidence", "area", "center"]
33
+ if strategy not in valid_strategies:
34
+ raise ValueError(
35
+ f"Invalid strategy '{strategy}'. Valid strategies: {valid_strategies}"
36
+ )
37
+
38
+ if strategy == "confidence_area":
39
+ # Select mask with highest confidence * area product
40
+ scores = []
41
+ for pred in predictions:
42
+ confidence = pred["confidence"]
43
+ mask = pred["mask"]
44
+ area = np.sum(mask > 0)
45
+ scores.append(confidence * area)
46
+ selected_idx = int(np.argmax(scores))
47
+
48
+ elif strategy == "confidence":
49
+ # Select mask with highest confidence
50
+ confidences = [pred["confidence"] for pred in predictions]
51
+ selected_idx = int(np.argmax(confidences))
52
+
53
+ elif strategy == "area":
54
+ # Select mask with largest area
55
+ areas = [np.sum(pred["mask"] > 0) for pred in predictions]
56
+ selected_idx = int(np.argmax(areas))
57
+
58
+ elif strategy == "center":
59
+ # Select mask closest to image center
60
+ # bbox_xywhn is normalized to [0, 1], so the image size is not needed
61
+ distances = []
62
+ for pred in predictions:
63
+ bbox = pred["bbox_xywhn"]
64
+ # Center is already normalized to [0, 1]
65
+ x_center = bbox["x_center"]
66
+ y_center = bbox["y_center"]
67
+ # Distance from image center (0.5, 0.5)
68
+ dist = np.sqrt((x_center - 0.5) ** 2 + (y_center - 0.5) ** 2)
69
+ distances.append(dist)
70
+ selected_idx = int(np.argmin(distances))
71
+
72
+ return selected_idx, predictions[selected_idx]
73
+
74
+
75
+ def run_mask_selection_stage(
76
+ predictions: list[dict],
77
+ strategy: str = "confidence_area",
78
+ ) -> dict:
79
+ """Run mask selection stage.
80
+
81
+ This stage selects the best mask from multiple YOLO predictions using
82
+ the specified selection strategy.
83
+
84
+ Args:
85
+ predictions: List of prediction dicts from segmentation stage
86
+ strategy: Selection strategy (default: 'confidence_area')
87
+
88
+ Returns:
89
+ Stage dict with structure:
90
+ {
91
+ "stage_id": "mask_selection",
92
+ "stage_name": "Mask Selection",
93
+ "description": "Select best mask from predictions",
94
+ "config": {
95
+ "strategy": str
96
+ },
97
+ "metrics": {
98
+ "num_candidates": int,
99
+ "selected_index": int,
100
+ "selected_confidence": float
101
+ },
102
+ "data": {
103
+ "selected_prediction": dict,
104
+ "metadata": {
105
+ "strategy": str,
106
+ "selected_index": int,
107
+ "num_candidates": int,
108
+ "confidence": float,
109
+ "mask_area": int
110
+ }
111
+ }
112
+ }
113
+
114
+ Raises:
115
+ ValueError: If predictions list is empty or strategy is invalid
116
+ """
117
+ logger.info(f"Selecting best mask using strategy: {strategy}")
118
+
119
+ # Select best mask
120
+ selected_idx, selected_pred = select_best_mask(predictions, strategy)
121
+
122
+ # Compute metadata
123
+ mask_area = int(np.sum(selected_pred["mask"] > 0))
124
+ confidence = selected_pred["confidence"]
125
+
126
+ logger.info(
127
+ f"Selected mask {selected_idx} (confidence={confidence:.3f}, area={mask_area})"
128
+ )
129
+
130
+ # Return standardized stage dict
131
+ return {
132
+ "stage_id": "mask_selection",
133
+ "stage_name": "Mask Selection",
134
+ "description": "Select best mask from predictions",
135
+ "config": {
136
+ "strategy": strategy,
137
+ },
138
+ "metrics": {
139
+ "num_candidates": len(predictions),
140
+ "selected_index": selected_idx,
141
+ "selected_confidence": confidence,
142
+ },
143
+ "data": {
144
+ "selected_prediction": selected_pred,
145
+ "metadata": {
146
+ "strategy": strategy,
147
+ "selected_index": selected_idx,
148
+ "num_candidates": len(predictions),
149
+ "confidence": confidence,
150
+ "mask_area": mask_area,
151
+ },
152
+ },
153
+ }
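The four selection strategies can disagree on the same set of candidates. The sketch below reproduces the scoring rules of `select_best_mask` in plain Python on made-up data: each prediction carries a precomputed `area` and a normalized `center` instead of a mask array and `bbox_xywhn`, and `select_index` is a hypothetical name. These simplifications are assumptions for illustration only.

```python
import math

# Made-up candidates: a small, confident, off-center mask and a large,
# less confident, centered one.
predictions = [
    {"confidence": 0.9, "area": 100, "center": (0.9, 0.9)},
    {"confidence": 0.6, "area": 400, "center": (0.5, 0.55)},
]


def select_index(predictions: list[dict], strategy: str = "confidence_area") -> int:
    """Return the index of the best prediction under the given strategy."""
    if not predictions:
        raise ValueError("Predictions list is empty")

    if strategy == "confidence_area":
        scores = [p["confidence"] * p["area"] for p in predictions]
    elif strategy == "confidence":
        scores = [p["confidence"] for p in predictions]
    elif strategy == "area":
        scores = [p["area"] for p in predictions]
    elif strategy == "center":
        # Negate the distance to (0.5, 0.5) so that "closest" becomes "highest score".
        scores = [-math.hypot(p["center"][0] - 0.5, p["center"][1] - 0.5) for p in predictions]
    else:
        raise ValueError(f"Invalid strategy '{strategy}'")

    # First index of the maximum, matching argmax tie-breaking.
    return max(range(len(scores)), key=scores.__getitem__)


for name in ("confidence_area", "confidence", "area", "center"):
    print(name, select_index(predictions, name))
```

Here only the `confidence` strategy picks the first candidate; the other three prefer the larger, centered mask.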
src/snowleopard_reid/pipeline/stages/matching.py ADDED
@@ -0,0 +1,593 @@
1
+ """Matching stage for snow leopard identification.
2
+
3
+ This module handles matching query features against the catalog using
4
+ LightGlue and computing matching metrics.
5
+ """
6
+
7
+ import logging
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+ import numpy as np
12
+ import torch
13
+ from lightglue import LightGlue
14
+ from scipy.stats import wasserstein_distance
15
+
16
+ from snowleopard_reid import get_device
17
+ from snowleopard_reid.catalog import (
18
+ get_all_catalog_features,
19
+ get_catalog_metadata_for_id,
20
+ get_filtered_catalog_features,
21
+ load_catalog_index,
22
+ )
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ def run_matching_stage(
28
+ query_features: dict[str, torch.Tensor],
29
+ catalog_path: Path | str,
30
+ top_k: int = 5,
31
+ extractor: str = "sift",
32
+ device: str | None = None,
33
+ query_image_path: str | None = None,
34
+ pairwise_output_dir: Path | None = None,
35
+ filter_locations: list[str] | None = None,
36
+ filter_body_parts: list[str] | None = None,
37
+ ) -> dict:
38
+ """Match query against catalog.
39
+
40
+ This stage matches the query features against all catalog images using
41
+ LightGlue, computes metrics, and ranks matches.
42
+
43
+ Args:
44
+ query_features: Query features dict with keypoints, descriptors, scores
45
+ catalog_path: Path to catalog root directory (e.g., data/08_catalog/v1.0/)
46
+ top_k: Number of top matches to return (default: 5)
47
+ extractor: Feature extractor used (default: 'sift')
48
+ device: Device to run matching on ('cpu', 'cuda', or None for auto-detect)
49
+ query_image_path: Path to query image (optional, for pairwise data)
50
+ pairwise_output_dir: Directory to save pairwise match data (optional)
51
+ filter_locations: List of locations to filter catalog by (e.g., ["skycrest_valley"])
52
+ filter_body_parts: List of body parts to filter catalog by (e.g., ["head", "right_flank"])
53
+
54
+ Returns:
55
+ Stage dict with structure:
56
+ {
57
+ "stage_id": "matching",
58
+ "stage_name": "Matching",
59
+ "description": "Match query against catalog using LightGlue",
60
+ "config": {...},
61
+ "metrics": {...},
62
+ "data": {
63
+ "catalog_info": {...},
64
+ "matches": [...]
65
+ }
66
+ }
67
+
68
+ Raises:
69
+ FileNotFoundError: If catalog doesn't exist
70
+ ValueError: If extractor not available in catalog
71
+ RuntimeError: If matching fails
72
+ """
73
+ catalog_path = Path(catalog_path)
74
+
75
+ if not catalog_path.exists():
76
+ raise FileNotFoundError(f"Catalog not found: {catalog_path}")
77
+
78
+ # Auto-detect device
79
+ device = get_device(device=device, verbose=True)
80
+
81
+ # Load catalog index
82
+ logger.info(f"Loading catalog from {catalog_path}")
83
+ catalog_index = load_catalog_index(catalog_path)
84
+ logger.info(
85
+ f"Catalog v{catalog_index['catalog_version']}: "
86
+ f"{catalog_index['statistics']['total_individuals']} individuals, "
87
+ f"{catalog_index['statistics']['total_reference_images']} images"
88
+ )
89
+
90
+ # Load catalog features (with optional filtering)
91
+ if filter_locations or filter_body_parts:
92
+ filter_info = []
93
+ if filter_locations:
94
+ filter_info.append(f"locations={filter_locations}")
95
+ if filter_body_parts:
96
+ filter_info.append(f"body_parts={filter_body_parts}")
97
+ logger.info(f"Loading filtered catalog features ({', '.join(filter_info)})")
98
+ try:
99
+ catalog_features = get_filtered_catalog_features(
100
+ catalog_root=catalog_path,
101
+ extractor=extractor,
102
+ locations=filter_locations,
103
+ body_parts=filter_body_parts,
104
+ )
105
+ logger.info(f"Loaded {len(catalog_features)} filtered catalog features")
106
+ except ValueError as e:
107
+ raise ValueError(f"Failed to load filtered catalog features: {e}") from e
108
+ else:
109
+ logger.info(f"Loading catalog features (extractor: {extractor})")
110
+ try:
111
+ catalog_features = get_all_catalog_features(
112
+ catalog_root=catalog_path, extractor=extractor
113
+ )
114
+ logger.info(f"Loaded {len(catalog_features)} catalog features")
115
+ except ValueError as e:
116
+ raise ValueError(f"Failed to load catalog features: {e}") from e
117
+
118
+ # Initialize LightGlue matcher
119
+ logger.info(f"Initializing LightGlue matcher with {extractor} features")
120
+ try:
121
+ matcher = LightGlue(features=extractor).eval().to(device)
122
+ except Exception as e:
123
+ raise ValueError(
124
+ f"Failed to initialize LightGlue matcher with extractor '{extractor}': {e}"
125
+ ) from e
126
+
127
+ # Move query features to device and add batch dimension
128
+ query_feats = {}
129
+ for k, v in query_features.items():
130
+ if isinstance(v, torch.Tensor):
131
+ # Add batch dimension if not present
132
+ if v.ndim == 1:
133
+ v = v.unsqueeze(0)
134
+ elif v.ndim == 2:
135
+ v = v.unsqueeze(0)
136
+ query_feats[k] = v.to(device)
137
+ else:
138
+ query_feats[k] = v
139
+
140
+ # Serial matching: iterate through catalog
141
+ logger.info(f"Matching against {len(catalog_features)} catalog images")
142
+ matches_dict = {}
143
+ raw_matches_cache = {} # Store raw matches for pairwise saving
144
+
145
+ for catalog_id, catalog_feats in catalog_features.items():
146
+ # Move catalog features to device and add batch dimension
147
+ catalog_feats_device = {}
148
+ for k, v in catalog_feats.items():
149
+ if isinstance(v, torch.Tensor):
150
+ # Add batch dimension if not present
151
+ if v.ndim == 1:
152
+ v = v.unsqueeze(0)
153
+ elif v.ndim == 2:
154
+ v = v.unsqueeze(0)
155
+ catalog_feats_device[k] = v.to(device)
156
+ else:
157
+ catalog_feats_device[k] = v
158
+
159
+ # Run matcher
160
+ try:
161
+ with torch.no_grad():
162
+ matches = matcher(
163
+ {
164
+ "image0": query_feats,
165
+ "image1": catalog_feats_device,
166
+ }
167
+ )
168
+ except Exception as e:
169
+ logger.warning(f"Matching failed for {catalog_id}: {e}")
170
+ continue
171
+
172
+ # Compute metrics
173
+ try:
174
+ metrics = compute_match_metrics(matches)
175
+ matches_dict[catalog_id] = metrics
176
+
177
+ # Cache raw matches and features for top-k pairwise saving
178
+ if pairwise_output_dir is not None:
179
+ raw_matches_cache[catalog_id] = {
180
+ "matches": matches,
181
+ "catalog_features": catalog_feats,
182
+ }
183
+ except KeyError as e:
184
+ logger.warning(f"Failed to compute metrics for {catalog_id}: {e}")
185
+ continue
186
+
187
+ logger.info(f"Successfully matched against {len(matches_dict)} catalog images")
188
+
189
+ if not matches_dict:
190
+ raise RuntimeError(
191
+ "No successful matches found. All catalog images failed to match. "
192
+ "This may indicate a problem with feature extraction or format."
193
+ )
194
+
195
+ # Rank matches by Wasserstein distance
196
+ ranked_matches = rank_matches(matches_dict, metric="wasserstein", top_k=top_k)
197
+
198
+ # Enrich matches with catalog metadata
199
+ enriched_matches = []
200
+ for match in ranked_matches:
201
+ catalog_id = match["catalog_id"]
202
+ metadata = get_catalog_metadata_for_id(
203
+ catalog_root=catalog_path, catalog_id=catalog_id
204
+ )
205
+
206
+ if metadata is None:
207
+ logger.warning(f"No metadata found for {catalog_id}")
208
+ continue
209
+
210
+ enriched_match = {
211
+ "rank": match["rank"],
212
+ "catalog_id": catalog_id,
213
+ "leopard_name": metadata["leopard_name"],
214
+ "filepath": str(metadata["image_path"]),
215
+ "wasserstein": match["wasserstein"],
216
+ "auc": match["auc"],
217
+ "num_matches": match["num_matches"],
218
+ "individual_id": metadata["individual_id"],
219
+ }
220
+ enriched_matches.append(enriched_match)
221
+
222
+ if enriched_matches:
223
+ logger.info(
224
+ f"Top match: {enriched_matches[0]['leopard_name']} "
225
+ f"(wasserstein: {enriched_matches[0]['wasserstein']:.4f}, "
226
+ f"matches: {enriched_matches[0]['num_matches']})"
227
+ )
228
+
229
+ # Save pairwise match data for top-k matches
230
+ if pairwise_output_dir is not None and enriched_matches:
231
+ pairwise_output_dir = Path(pairwise_output_dir)
232
+ pairwise_output_dir.mkdir(parents=True, exist_ok=True)
233
+
234
+ logger.info(
235
+ f"Saving pairwise match data for top-{len(enriched_matches)} matches"
236
+ )
237
+
238
+ # Get query image size
239
+ query_image_size = query_features.get("image_size")
240
+ if isinstance(query_image_size, torch.Tensor):
241
+ query_image_size = query_image_size.cpu().numpy()
242
+
243
+ for enriched_match in enriched_matches:
244
+ catalog_id = enriched_match["catalog_id"]
245
+
246
+ # Skip if no cached data for this catalog_id
247
+ if catalog_id not in raw_matches_cache:
248
+ logger.warning(
249
+ f"No cached match data for {catalog_id}, skipping pairwise save"
250
+ )
251
+ enriched_match["pairwise_file"] = None
252
+ continue
253
+
254
+ # Get cached data
255
+ cached = raw_matches_cache[catalog_id]
256
+ matches = cached["matches"]
257
+ catalog_feats = cached["catalog_features"]
258
+
259
+ # Extract matched keypoints
260
+ try:
261
+ matched_data = extract_matched_keypoints(
262
+ query_features=query_features,
263
+ catalog_features=catalog_feats,
264
+ matches=matches,
265
+ )
266
+ except Exception as e:
267
+ logger.warning(
268
+ f"Failed to extract keypoints for {catalog_id}: {e}, skipping"
269
+ )
270
+ enriched_match["pairwise_file"] = None
271
+ continue
272
+
273
+ # Get catalog image size
274
+ catalog_image_size = catalog_feats.get("image_size")
275
+ if isinstance(catalog_image_size, torch.Tensor):
276
+ catalog_image_size = catalog_image_size.cpu().numpy()
277
+
278
+ # Build pairwise data
279
+ pairwise_data = {
280
+ "rank": enriched_match["rank"],
281
+ "catalog_id": catalog_id,
282
+ "leopard_name": enriched_match["leopard_name"],
283
+ "query_image_path": query_image_path or "",
284
+ "catalog_image_path": enriched_match["filepath"],
285
+ "query_image_size": query_image_size,
286
+ "catalog_image_size": catalog_image_size,
287
+ "query_keypoints": matched_data["query_keypoints"],
288
+ "catalog_keypoints": matched_data["catalog_keypoints"],
289
+ "match_scores": matched_data["match_scores"],
290
+ "wasserstein": enriched_match["wasserstein"],
291
+ "auc": enriched_match["auc"],
292
+ "num_matches": matched_data[
293
+ "num_matches"
294
+ ], # Use actual count from extracted keypoints
295
+ }
296
+
297
+ # Save as compressed NPZ
298
+ output_filename = f"rank_{enriched_match['rank']:02d}_{catalog_id}.npz"
299
+ output_path = pairwise_output_dir / output_filename
300
+
301
+ np.savez_compressed(output_path, **pairwise_data)
302
+
303
+ # Add pairwise file reference to enriched_match (relative to matching stage dir)
304
+ enriched_match["pairwise_file"] = f"pairwise/{output_filename}"
305
+
306
+ logger.info(f"Saved pairwise data to {pairwise_output_dir}")
307
+ else:
308
+ # Set pairwise_file to None if not saving pairwise data
309
+ for enriched_match in enriched_matches:
310
+ enriched_match["pairwise_file"] = None
311
+
312
+ # Return standardized stage dict
313
+ return {
314
+ "stage_id": "matching",
315
+ "stage_name": "Matching",
316
+ "description": "Match query against catalog using LightGlue",
317
+ "config": {
318
+ "top_k": top_k,
319
+ "extractor": extractor,
320
+ "device": device,
321
+ "catalog_path": str(catalog_path),
322
+ "filter_locations": filter_locations,
323
+ "filter_body_parts": filter_body_parts,
324
+ },
325
+ "metrics": {
326
+ "num_catalog_images": len(catalog_features),
327
+ "num_successful_matches": len(matches_dict),
328
+ "top_match_wasserstein": enriched_matches[0]["wasserstein"]
329
+ if enriched_matches
330
+ else 0.0,
331
+ "top_match_leopard_name": enriched_matches[0]["leopard_name"]
332
+ if enriched_matches
333
+ else "",
334
+ },
335
+ "data": {
336
+ "catalog_info": {
337
+ "catalog_version": catalog_index["catalog_version"],
338
+ "catalog_path": str(catalog_path),
339
+ "num_individuals": catalog_index["statistics"]["total_individuals"],
340
+ "num_reference_images": catalog_index["statistics"][
341
+ "total_reference_images"
342
+ ],
343
+ },
344
+ "matches": enriched_matches,
345
+ },
346
+ }
347
+
348
+
349
+ # ============================================================================
350
+ # Metrics Utilities
351
+ # ============================================================================
352
+
353
+
354
+ def compute_wasserstein_distance(scores: np.ndarray) -> float:
355
+ """Compute Wasserstein distance from null distribution.
356
+
357
+ The Wasserstein distance measures how far the match score distribution is from
358
+ a null distribution (all zeros). Higher values indicate better matches.
359
+ run_matching_stage ranks catalog images by this metric.
360
+
361
+ Args:
362
+ scores: Array of match scores (typically from matcher output)
363
+
364
+ Returns:
365
+ Wasserstein distance as a float
366
+
367
+ References:
368
+ Based on trout-reid implementation for animal re-identification
369
+ """
370
+ if len(scores) == 0:
371
+ return 0.0
372
+
373
+ # Null distribution: fixed-length array of zeros
374
+ # This represents no matches at all
375
+ # Using fixed length (1024) ensures all matches are comparable
376
+ # to the same reference distribution (follows trout-reID implementation)
377
+ x_null_distribution = np.zeros(1024)
378
+
379
+ # Compute Wasserstein (Earth Mover's) distance
380
+ distance = wasserstein_distance(x_null_distribution, scores)
381
+
382
+ return float(distance)
383
+
384
+
385
+ def compute_auc(scores: np.ndarray) -> float:
386
+ """Compute Area Under Curve (cumulative distribution) of match scores.
387
+
388
+ AUC represents the cumulative distribution of match scores.
389
+ Higher values indicate better matches.
390
+
391
+ Args:
392
+ scores: Array of match scores (typically from matcher output)
393
+
394
+ Returns:
395
+ AUC value as a float (0.0 to 1.0)
396
+
397
+ References:
398
+ Based on trout-reid implementation
399
+ """
400
+ if len(scores) == 0:
401
+ return 0.0
402
+
403
+ # Sort scores in ascending order
404
+ sorted_scores = np.sort(scores)
405
+
406
+ # Compute cumulative sum
407
+ cumsum = np.cumsum(sorted_scores)
408
+
409
+ # Normalize by total sum to get AUC in [0, 1]
410
+ if cumsum[-1] > 0:
411
+ auc = np.trapz(cumsum / cumsum[-1]) / len(scores)
412
+ else:
413
+ auc = 0.0
414
+
415
+ return float(auc)
416
+
417
+
418
+ def extract_match_scores(matches: dict[str, torch.Tensor]) -> np.ndarray:
419
+ """Extract match scores from matcher output.
420
+
421
+ Args:
422
+ matches: Dictionary from LightGlue matcher with keys:
423
+ - matches0: Tensor of matched indices
424
+ - matching_scores0: Tensor of match confidence scores
425
+
426
+ Returns:
427
+ Numpy array of match scores
428
+
429
+ Raises:
430
+ KeyError: If required keys are missing from matches dict
431
+ """
432
+ if "matching_scores0" not in matches:
433
+ raise KeyError("matches dictionary missing 'matching_scores0' key")
434
+
435
+ scores = matches["matching_scores0"]
436
+
437
+ # Convert to numpy
438
+ if isinstance(scores, torch.Tensor):
439
+ scores = scores.cpu().numpy()
440
+
441
+ # Filter out unmatched keypoints (score = 0 or negative)
442
+ valid_scores = scores[scores > 0]
443
+
444
+ return valid_scores
445
+
446
+
447
+ def extract_matched_keypoints(
448
+ query_features: dict[str, torch.Tensor],
449
+ catalog_features: dict[str, torch.Tensor],
450
+ matches: dict[str, torch.Tensor],
451
+ ) -> dict[str, np.ndarray]:
452
+ """Extract matched keypoint pairs from matcher output.
453
+
454
+ Args:
455
+ query_features: Query feature dict with 'keypoints' tensor [M, 2]
456
+ catalog_features: Catalog feature dict with 'keypoints' tensor [N, 2]
457
+ matches: Dictionary from LightGlue matcher with:
458
+ - matches0: Tensor [M] mapping query_idx → catalog_idx (-1 if no match)
459
+ - matching_scores0: Tensor [M] with match confidence scores
460
+
461
+ Returns:
462
+ Dictionary with:
463
+ - query_keypoints: ndarray [num_matches, 2] - matched query keypoints
464
+ - catalog_keypoints: ndarray [num_matches, 2] - matched catalog keypoints
465
+ - match_scores: ndarray [num_matches] - confidence scores
466
+ - num_matches: int - number of valid matches
467
+
468
+ Raises:
469
+ KeyError: If required keys are missing
470
+ """
471
+ if "matches0" not in matches or "matching_scores0" not in matches:
472
+ raise KeyError(
473
+ "matches dictionary missing 'matches0' or 'matching_scores0' keys"
474
+ )
475
+
476
+ # Get match indices and scores
477
+ matches0 = matches["matches0"] # Shape: [M]
478
+ scores0 = matches["matching_scores0"] # Shape: [M]
479
+
480
+ # Convert to numpy if tensors
481
+ if isinstance(matches0, torch.Tensor):
482
+ matches0 = matches0.cpu().numpy()
483
+ if isinstance(scores0, torch.Tensor):
484
+ scores0 = scores0.cpu().numpy()
485
+
486
+ # Remove batch dimension if present
487
+ if matches0.ndim == 2:
488
+ matches0 = matches0[0]
489
+ if scores0.ndim == 2:
490
+ scores0 = scores0[0]
491
+
492
+ # Filter valid matches (matched and score > 0)
493
+ valid_mask = (matches0 >= 0) & (scores0 > 0)
494
+ valid_indices = matches0[valid_mask].astype(int)
495
+ valid_scores = scores0[valid_mask]
496
+
497
+ # Get keypoints
498
+ query_kpts = query_features["keypoints"]
499
+ catalog_kpts = catalog_features["keypoints"]
500
+
501
+ # Convert to numpy if tensors
502
+ if isinstance(query_kpts, torch.Tensor):
503
+ query_kpts = query_kpts.cpu().numpy()
504
+ if isinstance(catalog_kpts, torch.Tensor):
505
+ catalog_kpts = catalog_kpts.cpu().numpy()
506
+
507
+ # Remove batch dimension if present
508
+ if query_kpts.ndim == 3:
509
+ query_kpts = query_kpts[0]
510
+ if catalog_kpts.ndim == 3:
511
+ catalog_kpts = catalog_kpts[0]
512
+
513
+ # Extract matched keypoints
514
+ query_matched = query_kpts[valid_mask]
515
+ catalog_matched = catalog_kpts[valid_indices]
516
+
517
+ return {
518
+ "query_keypoints": query_matched,
519
+ "catalog_keypoints": catalog_matched,
520
+ "match_scores": valid_scores,
521
+ "num_matches": len(valid_scores),
522
+ }
523
+
524
+
525
+ def rank_matches(
526
+ matches_dict: dict[str, dict[str, Any]],
527
+ metric: str = "wasserstein",
528
+ top_k: int | None = None,
529
+ ) -> list[dict[str, Any]]:
530
+ """Rank matches by specified metric.
531
+
532
+ Args:
533
+ matches_dict: Dictionary mapping catalog_id to match info
534
+ metric: Metric to rank by ('wasserstein' or 'auc')
535
+ top_k: Number of top matches to return (None = all)
536
+
537
+ Returns:
538
+ List of match dictionaries sorted by metric (best first)
539
+
540
+ Raises:
541
+ ValueError: If metric is not supported
542
+ """
543
+ if metric not in ["wasserstein", "auc"]:
544
+ raise ValueError(f"Unsupported metric: {metric}. Use 'wasserstein' or 'auc'")
545
+
546
+ # Convert dict to list with catalog_id included
547
+ matches_list = [
548
+ {"catalog_id": cid, **match_info} for cid, match_info in matches_dict.items()
549
+ ]
550
+
551
+ # Sort by metric (descending - higher is better for both metrics)
552
+ sorted_matches = sorted(
553
+ matches_list,
554
+ key=lambda x: x.get(metric, 0.0),
555
+ reverse=True,
556
+ )
557
+
558
+ # Add rank
559
+ for rank, match in enumerate(sorted_matches, start=1):
560
+ match["rank"] = rank
561
+
562
+ # Return top_k if specified
563
+ if top_k is not None:
564
+ return sorted_matches[:top_k]
565
+
566
+ return sorted_matches
567
+
568
+
569
+ def compute_match_metrics(matches: dict[str, torch.Tensor]) -> dict[str, float | int]:
570
+ """Compute all matching metrics for a single match result.
571
+
572
+ Args:
573
+ matches: Dictionary from LightGlue matcher
574
+
575
+ Returns:
576
+ Dictionary with computed metrics:
577
+ - wasserstein: float
578
+ - auc: float
579
+ - num_matches: int
580
+
581
+ Raises:
582
+ KeyError: If matches dict is missing required keys
583
+ """
584
+ try:
585
+ scores = extract_match_scores(matches)
586
+
587
+ return {
588
+ "wasserstein": compute_wasserstein_distance(scores),
589
+ "auc": compute_auc(scores),
590
+ "num_matches": len(scores),
591
+ }
592
+ except KeyError as e:
593
+ raise KeyError(f"Failed to compute metrics: {e}")
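The Wasserstein metric in `compute_wasserstein_distance` compares the match scores against an all-zero null distribution. As a dependency-free sketch (not the module's code; `wasserstein_1d` is a hypothetical helper written for illustration), the 1-D distance can be computed as the area between the two empirical CDFs. Under that definition, and assuming non-negative scores, the distance to an all-zero sample reduces to the mean score, whatever the length of the null array:

```python
from bisect import bisect_right


def wasserstein_1d(u, v):
    """Empirical 1-D Wasserstein-1 distance: area between the two CDFs."""
    u_sorted, v_sorted = sorted(u), sorted(v)
    points = sorted(set(u_sorted) | set(v_sorted))
    total = 0.0
    for left, right in zip(points, points[1:]):
        # Fraction of each sample that is <= left (CDF value on [left, right))
        cdf_u = bisect_right(u_sorted, left) / len(u_sorted)
        cdf_v = bisect_right(v_sorted, left) / len(v_sorted)
        total += abs(cdf_u - cdf_v) * (right - left)
    return total


scores = [0.2, 0.5, 0.9]  # illustrative match confidences, not real data
mean_score = sum(scores) / len(scores)

# Null distribution of zeros, at two different lengths
dist_1024 = wasserstein_1d([0.0] * 1024, scores)
dist_8 = wasserstein_1d([0.0] * 8, scores)
```

If the same holds for the scores LightGlue returns, ranking by `wasserstein` orders catalog images by the mean confidence of their kept matches and does not depend on how many matches there are, so `num_matches` is worth inspecting alongside it.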
src/snowleopard_reid/pipeline/stages/preprocess.py ADDED
@@ -0,0 +1,142 @@
1
+ """Preprocessing stage for cropping and masking snow leopard images.
2
+
3
+ This module provides preprocessing operations to extract and mask the leopard region
4
+ from the full image.
5
+ """
6
+
7
+ import logging
8
+ from pathlib import Path
9
+
10
+ import cv2
11
+ import numpy as np
12
+ from PIL import Image
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ def run_preprocess_stage(
18
+ image_path: Path | str,
19
+ mask: np.ndarray,
20
+ padding: int = 5,
21
+ ) -> dict:
22
+ """Run preprocessing stage.
23
+
24
+ This stage crops the image to the mask bounding box with padding and applies
25
+ the mask to isolate the leopard region.
26
+
27
+ Args:
28
+ image_path: Path to input image
29
+ mask: Binary mask (H×W, uint8) from segmentation
30
+ padding: Padding around bbox in pixels (default: 5)
31
+
32
+ Returns:
33
+ Stage dict with structure:
34
+ {
35
+ "stage_id": "preprocessing",
36
+ "stage_name": "Preprocessing",
37
+ "description": "Crop and mask leopard region",
38
+ "config": {
39
+ "padding": int
40
+ },
41
+ "metrics": {
42
+ "original_size": {"width": int, "height": int},
43
+ "crop_size": {"width": int, "height": int}
44
+ },
45
+ "data": {
46
+ "cropped_image": PIL.Image,
47
+ "metadata": {
48
+ "original_size": {"width": int, "height": int},
49
+ "crop_bbox": {"x_min": int, "y_min": int, "x_max": int, "y_max": int},
50
+ "crop_size": {"width": int, "height": int}
51
+ }
52
+ }
53
+ }
54
+
55
+ Raises:
56
+ FileNotFoundError: If image doesn't exist
57
+ ValueError: If mask is empty (no pixels > 0)
58
+ """
59
+ image_path = Path(image_path)
60
+
61
+ if not image_path.exists():
62
+ raise FileNotFoundError(f"Image not found: {image_path}")
63
+
64
+ logger.info(f"Preprocessing image: {image_path}")
65
+
66
+ # Load image
67
+ image = cv2.imread(str(image_path))
68
+ if image is None:
69
+ raise RuntimeError(f"Failed to load image: {image_path}")
70
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
71
+ image_height, image_width = image_rgb.shape[:2]
72
+
73
+ # Resize mask to match image dimensions if needed
74
+ if mask.shape[:2] != (image_height, image_width):
75
+ mask_resized = cv2.resize(
76
+ mask.astype(np.uint8),
77
+ (image_width, image_height),
78
+ interpolation=cv2.INTER_NEAREST,
79
+ )
80
+ else:
81
+ mask_resized = mask
82
+
83
+ # Find bounding box of mask
84
+ rows = np.any(mask_resized > 0, axis=1)
85
+ cols = np.any(mask_resized > 0, axis=0)
86
+
87
+ if not np.any(rows) or not np.any(cols):
88
+ raise ValueError("Mask is empty (no pixels > 0)")
89
+
90
+ y_min, y_max = np.where(rows)[0][[0, -1]]
91
+ x_min, x_max = np.where(cols)[0][[0, -1]]
92
+
93
+ # Add padding
94
+ x_min = max(0, x_min - padding)
95
+ y_min = max(0, y_min - padding)
96
+ x_max = min(image_width - 1, x_max + padding)
97
+ y_max = min(image_height - 1, y_max + padding)
98
+
99
+ # Crop image and mask
100
+ cropped_image = image_rgb[y_min : y_max + 1, x_min : x_max + 1]
101
+ cropped_mask = mask_resized[y_min : y_max + 1, x_min : x_max + 1]
102
+
103
+ # Apply mask (set non-masked pixels to black)
104
+ masked_image = cropped_image.copy()
105
+ masked_image[cropped_mask == 0] = 0
106
+
107
+ # Convert to PIL Image
108
+ cropped_pil = Image.fromarray(masked_image)
109
+
110
+ crop_height, crop_width = masked_image.shape[:2]
111
+
112
+ logger.info(
113
+ f"Cropped from {image_width}x{image_height} to {crop_width}x{crop_height} "
114
+ f"(padding={padding}px)"
115
+ )
116
+
117
+ # Return standardized stage dict
118
+ return {
119
+ "stage_id": "preprocessing",
120
+ "stage_name": "Preprocessing",
121
+ "description": "Crop and mask leopard region",
122
+ "config": {
123
+ "padding": padding,
124
+ },
125
+ "metrics": {
126
+ "original_size": {"width": image_width, "height": image_height},
127
+ "crop_size": {"width": crop_width, "height": crop_height},
128
+ },
129
+ "data": {
130
+ "cropped_image": cropped_pil,
131
+ "metadata": {
132
+ "original_size": {"width": image_width, "height": image_height},
133
+ "crop_bbox": {
134
+ "x_min": int(x_min),
135
+ "y_min": int(y_min),
136
+ "x_max": int(x_max),
137
+ "y_max": int(y_max),
138
+ },
139
+ "crop_size": {"width": crop_width, "height": crop_height},
140
+ },
141
+ },
142
+ }
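The bounding-box step in `run_preprocess_stage` can be illustrated without NumPy or OpenCV. The sketch below is a simplified stand-in, not the stage itself: `padded_mask_bbox` is a hypothetical helper that takes a mask as nested lists and returns the same kind of inclusive, clamped box the stage stores under `crop_bbox`.

```python
def padded_mask_bbox(mask, padding=5):
    """Return (x_min, y_min, x_max, y_max), inclusive, clamped to the mask bounds."""
    height, width = len(mask), len(mask[0])
    # Indices of rows and columns that contain at least one foreground pixel
    rows = [y for y, row in enumerate(mask) if any(v > 0 for v in row)]
    cols = [x for x in range(width) if any(row[x] > 0 for row in mask)]
    if not rows or not cols:
        raise ValueError("Mask is empty (no pixels > 0)")
    x_min = max(0, cols[0] - padding)
    y_min = max(0, rows[0] - padding)
    x_max = min(width - 1, cols[-1] + padding)
    y_max = min(height - 1, rows[-1] + padding)
    return x_min, y_min, x_max, y_max


# Toy 6x8 mask with two foreground pixels (illustrative values only)
mask = [[0] * 8 for _ in range(6)]
mask[2][3] = 255
mask[3][5] = 255

bbox = padded_mask_bbox(mask, padding=1)
crop_width = bbox[2] - bbox[0] + 1
crop_height = bbox[3] - bbox[1] + 1
clamped = padded_mask_bbox(mask, padding=5)  # padding larger than the margins
```

Because the box is inclusive, the crop spans `x_max - x_min + 1` columns and `y_max - y_min + 1` rows, which matches the `+ 1` in the stage's slices.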
src/snowleopard_reid/pipeline/stages/segmentation.py ADDED
@@ -0,0 +1,453 @@
1
+ """Segmentation stage using YOLO or GDINO+SAM for snow leopard detection.
2
+
3
+ This module provides segmentation stages that detect and segment snow leopards
4
+ in query images using either:
5
+ 1. YOLO (end-to-end learned segmentation)
6
+ 2. GDINO+SAM (zero-shot detection + prompted segmentation)
7
+ """
8
+
9
+ import logging
10
+ from pathlib import Path
11
+ from typing import Any
12
+
13
+ import cv2
14
+ import numpy as np
15
+ import torch
16
+ from PIL import Image
17
+ from segment_anything_hq import SamPredictor, sam_model_registry
18
+ from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
19
+ from ultralytics import YOLO
20
+
21
+ from snowleopard_reid import get_device
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ def load_gdino_model(
27
+ model_id: str = "IDEA-Research/grounding-dino-base",
28
+ device: str | None = None,
29
+ ) -> tuple[Any, Any]:
30
+ """Load Grounding DINO model and processor.
31
+
32
+ Args:
33
+ model_id: HuggingFace model identifier
34
+ device: Device to load model on (None = auto-detect)
35
+
36
+ Returns:
37
+ Tuple of (processor, model)
38
+ """
39
+ device = get_device(device=device, verbose=True)
40
+
41
+ logger.info(f"Loading Grounding DINO model: {model_id}")
42
+ processor = AutoProcessor.from_pretrained(model_id)
43
+ model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
44
+ model = model.to(device)
45
+ model.eval()
46
+
47
+ logger.info("Grounding DINO model loaded successfully")
48
+ return processor, model
49
+
50
+
51
+ def load_sam_predictor(
52
+ checkpoint_path: Path | str,
53
+ model_type: str = "vit_l",
54
+ device: str | None = None,
55
+ ) -> SamPredictor:
56
+ """Load SAM HQ predictor.
57
+
58
+ Args:
59
+ checkpoint_path: Path to SAM HQ checkpoint file
60
+ model_type: Model type (vit_b, vit_l, vit_h)
61
+ device: Device to load model on (None = auto-detect)
62
+
63
+ Returns:
64
+ SamPredictor instance
65
+ """
66
+ checkpoint_path = Path(checkpoint_path)
67
+ if not checkpoint_path.exists():
68
+ raise FileNotFoundError(f"SAM checkpoint not found: {checkpoint_path}")
69
+
70
+ device_str = get_device(device=device, verbose=True)
71
+
72
+ logger.info(f"Loading SAM HQ model: {model_type}")
73
+ sam = sam_model_registry[model_type](checkpoint=str(checkpoint_path))
74
+ sam.to(device=device_str)
75
+
76
+ predictor = SamPredictor(sam)
77
+ logger.info("SAM HQ model loaded successfully")
78
+
79
+ return predictor
80
+
81
+
82
+ def _run_yolo_segmentation(
83
+ model: YOLO,
84
+ image_path: Path,
85
+ confidence_threshold: float,
86
+ device: str,
87
+ ) -> dict:
88
+ """Run YOLO segmentation (internal implementation).
89
+
90
+ Args:
91
+ model: Pre-loaded YOLO model
92
+ image_path: Path to input image
93
+ confidence_threshold: Minimum confidence to keep predictions
94
+ device: Device to run on
95
+
96
+ Returns:
97
+ Standardized stage dict
98
+ """
99
+ # Load image to get size
100
+ image = cv2.imread(str(image_path))
101
+ if image is None:
102
+ raise RuntimeError(f"Failed to load image: {image_path}")
103
+ image_height, image_width = image.shape[:2]
104
+
105
+ # Run inference
106
+ try:
107
+ results = model(
108
+ str(image_path),
109
+ conf=confidence_threshold,
110
+ device=device,
111
+ verbose=False,
112
+ )
113
+ except Exception as e:
114
+ raise RuntimeError(f"YOLO inference failed: {e}") from e
115
+
116
+ # Parse results
117
+ predictions = []
118
+ result = results[0] # Single image, so single result
119
+
120
+ # Debug: log result attributes
121
+ logger.info(f"Result object type: {type(result)}")
122
+ logger.info(f"Result has boxes: {result.boxes is not None}")
123
+ logger.info(f"Result has masks: {result.masks is not None}")
124
+ if result.boxes is not None:
125
+ logger.info(f"Number of boxes: {len(result.boxes)}")
126
+ if result.masks is not None:
127
+ logger.info(f"Number of masks: {len(result.masks)}")
128
+
129
+ # Check if any detections found
130
+ if result.masks is None or len(result.masks) == 0:
131
+ logger.warning(f"No detections found for {image_path}")
132
+ logger.warning(
133
+ f"Boxes present: {result.boxes is not None}, Masks present: {result.masks is not None}"
134
+ )
135
+ else:
136
+ # Extract masks and metadata
137
+ for idx in range(len(result.masks)):
138
+ # Get mask (binary, H×W)
139
+ mask = result.masks.data[idx].cpu().numpy() # Shape: (H, W)
140
+ mask = (mask * 255).astype(np.uint8) # Convert to 0-255
141
+
142
+ # Get bounding box (normalized xywh format)
143
+ bbox = result.boxes.xywhn[idx].cpu().numpy() # Shape: (4,)
144
+ x_center, y_center, width, height = bbox
145
+
146
+ # Get confidence
147
+ confidence = float(result.boxes.conf[idx].cpu().numpy())
148
+
149
+ # Get class info
150
+ class_id = int(result.boxes.cls[idx].cpu().numpy())
151
+ class_name = result.names[class_id]
152
+
153
+ predictions.append(
154
+ {
155
+ "mask": mask,
156
+ "confidence": confidence,
157
+ "bbox_xywhn": {
158
+ "x_center": float(x_center),
159
+ "y_center": float(y_center),
160
+ "width": float(width),
161
+ "height": float(height),
162
+ },
163
+ "class_id": class_id,
164
+ "class_name": class_name,
165
+ }
166
+ )
167
+
168
+ logger.info(
169
+ f"Found {len(predictions)} predictions (confidence >= {confidence_threshold})"
170
+ )
171
+
172
+ # Return standardized stage dict
173
+ return {
174
+ "stage_id": "segmentation",
175
+ "stage_name": "YOLO Segmentation",
176
+ "description": "Snow leopard detection and segmentation using YOLO",
177
+ "config": {
178
+ "strategy": "yolo",
179
+ "confidence_threshold": confidence_threshold,
180
+ "device": device,
181
+ },
182
+ "metrics": {
183
+ "num_predictions": len(predictions),
184
+ },
185
+ "data": {
186
+ "image_path": str(image_path),
187
+ "image_size": {"width": image_width, "height": image_height},
188
+ "predictions": predictions,
189
+ },
190
+ }
191
+
192
+
193
+ def _run_gdino_sam_segmentation(
194
+ gdino_processor: Any,
195
+ gdino_model: Any,
196
+ sam_predictor: SamPredictor,
197
+ image_path: Path,
198
+ confidence_threshold: float,
199
+ text_prompt: str,
200
+ box_threshold: float,
201
+ text_threshold: float,
202
+ device: str,
203
+ ) -> dict:
204
+ """Run GDINO+SAM segmentation (internal implementation).
205
+
206
+ Args:
207
+ gdino_processor: Grounding DINO processor
208
+ gdino_model: Grounding DINO model
209
+ sam_predictor: SAM HQ predictor
210
+ image_path: Path to input image
211
+ confidence_threshold: Minimum confidence to keep predictions
212
+ text_prompt: Text prompt for GDINO
213
+ box_threshold: GDINO box threshold
214
+ text_threshold: GDINO text threshold
215
+ device: Device to run on
216
+
217
+ Returns:
218
+ Standardized stage dict
219
+ """
220
+ # Load image (PIL for GDINO, numpy for SAM)
221
+ image_pil = Image.open(image_path).convert("RGB")
222
+ image_np = np.array(image_pil)
223
+ image_height, image_width = image_np.shape[:2]
224
+
225
+ # Run Grounding DINO detection
226
+ logger.info("Running Grounding DINO detection...")
227
+ inputs = gdino_processor(images=image_pil, text=text_prompt, return_tensors="pt")
228
+ inputs = {k: v.to(device) for k, v in inputs.items()}
229
+
230
+ with torch.no_grad():
231
+ outputs = gdino_model(**inputs)
232
+
233
+ # Post-process GDINO outputs
234
+ results = gdino_processor.post_process_grounded_object_detection(
235
+ outputs,
236
+ inputs["input_ids"],
237
+ threshold=box_threshold,
238
+ text_threshold=text_threshold,
239
+ target_sizes=[image_pil.size[::-1]], # (height, width)
240
+ )[0]
241
+
242
+ # Unpack detections; prefer "text_labels" and fall back to "labels"
243
+ labels = results.get("text_labels", results.get("labels", []))
244
+ boxes = results["boxes"]
245
+ scores = results["scores"]
246
+
247
+ logger.info(f"GDINO detected {len(boxes)} objects")
248
+
249
+ # Filter predictions by confidence threshold
250
+ filtered_detections = [
251
+ (box, score, label)
252
+ for box, score, label in zip(boxes, scores, labels)
253
+ if float(score) >= confidence_threshold
254
+ ]
255
+
256
+ logger.info(
257
+ f"Filtered to {len(filtered_detections)} detections (confidence >= {confidence_threshold})"
258
+ )
+
+     if not filtered_detections:
+         logger.warning(f"No detections found for {image_path}")
+         predictions = []
+     else:
+         # Set image for SAM (do this once)
+         logger.info("Running SAM HQ segmentation...")
+         sam_predictor.set_image(image_np)
+
+         predictions = []
+         for idx, (box, gdino_score, label) in enumerate(filtered_detections):
+             # Box is already in pixel coordinates (xyxy); format it for SAM
+             x_min, y_min, x_max, y_max = box
+             bbox_xyxy = np.array(
+                 [float(x_min), float(y_min), float(x_max), float(y_max)]
+             )
+
+             # Run SAM with bounding box prompt
+             masks, sam_scores, logits = sam_predictor.predict(
+                 box=bbox_xyxy[None, :],
+                 multimask_output=False,
+                 hq_token_only=True,
+             )
+
+             # Get mask (first and only mask, since multimask_output=False)
+             mask = masks[0]  # Shape: (H, W), boolean
+             sam_score = float(sam_scores[0])
+
+             # Convert mask to uint8 (0-255)
+             mask_uint8 = (mask * 255).astype(np.uint8)
+
+             # Convert bbox to normalized, center-based xywh format (same as YOLO)
+             x_center = (float(x_min) + float(x_max)) / 2 / image_width
+             y_center = (float(y_min) + float(y_max)) / 2 / image_height
+             width = (float(x_max) - float(x_min)) / image_width
+             height = (float(y_max) - float(y_min)) / image_height
+
+             predictions.append(
+                 {
+                     "mask": mask_uint8,
+                     "confidence": float(
+                         gdino_score
+                     ),  # Use GDINO score as primary confidence
+                     "bbox_xywhn": {
+                         "x_center": x_center,
+                         "y_center": y_center,
+                         "width": width,
+                         "height": height,
+                     },
+                     "class_id": 0,  # Single class (snow leopard)
+                     "class_name": label,
+                     # Additional metadata
+                     "sam_score": sam_score,
+                     "gdino_score": float(gdino_score),
+                 }
+             )
+
+         logger.info(f"Generated {len(predictions)} segmentation masks")
+
+     # Return standardized stage dict
+     return {
+         "stage_id": "segmentation",
+         "stage_name": "GDINO+SAM Segmentation",
+         "description": "Snow leopard detection using Grounding DINO and segmentation using SAM HQ",
+         "config": {
+             "strategy": "gdino_sam",
+             "confidence_threshold": confidence_threshold,
+             "text_prompt": text_prompt,
+             "box_threshold": box_threshold,
+             "text_threshold": text_threshold,
+             "device": device,
+         },
+         "metrics": {
+             "num_predictions": len(predictions),
+         },
+         "data": {
+             "image_path": str(image_path),
+             "image_size": {"width": image_width, "height": image_height},
+             "predictions": predictions,
+         },
+     }
+
+
+ def run_segmentation_stage(
+     image_path: Path | str,
+     strategy: str = "yolo",
+     confidence_threshold: float = 0.5,
+     device: str | None = None,
+     # YOLO-specific parameters
+     yolo_model: YOLO | None = None,
+     # GDINO+SAM-specific parameters
+     gdino_processor: Any | None = None,
+     gdino_model: Any | None = None,
+     sam_predictor: SamPredictor | None = None,
+     text_prompt: str = "a snow leopard.",
+     box_threshold: float = 0.30,
+     text_threshold: float = 0.20,
+ ) -> dict:
+     """Run segmentation on query image using specified strategy.
+
+     This stage performs snow leopard detection and segmentation using either:
+     - YOLO: End-to-end learned segmentation
+     - GDINO+SAM: Zero-shot detection + prompted segmentation
+
+     Args:
+         image_path: Path to input image
+         strategy: Segmentation strategy ("yolo" or "gdino_sam")
+         confidence_threshold: Minimum confidence to keep predictions (default: 0.5)
+         device: Device to run on ('cpu', 'cuda', or None for auto-detect)
+         yolo_model: Pre-loaded YOLO model (required if strategy="yolo")
+         gdino_processor: Pre-loaded GDINO processor (required if strategy="gdino_sam")
+         gdino_model: Pre-loaded GDINO model (required if strategy="gdino_sam")
+         sam_predictor: Pre-loaded SAM predictor (required if strategy="gdino_sam")
+         text_prompt: Text prompt for GDINO (default: "a snow leopard.")
+         box_threshold: GDINO box confidence threshold (default: 0.30)
+         text_threshold: GDINO text matching threshold (default: 0.20)
+
+     Returns:
+         Stage dict with structure:
+         {
+             "stage_id": "segmentation",
+             "stage_name": str,
+             "description": str,
+             "config": {
+                 "strategy": str,
+                 "confidence_threshold": float,
+                 "device": str,
+                 ...
+             },
+             "metrics": {
+                 "num_predictions": int
+             },
+             "data": {
+                 "image_path": str,
+                 "image_size": {"width": int, "height": int},
+                 "predictions": [
+                     {
+                         "mask": np.ndarray (H×W, uint8),
+                         "confidence": float,
+                         "bbox_xywhn": {...},
+                         "class_id": int,
+                         "class_name": str,
+                         # Optional (GDINO+SAM only)
+                         "sam_score": float,
+                         "gdino_score": float,
+                     },
+                     ...
+                 ]
+             }
+         }
+
+     Raises:
+         ValueError: If strategy is invalid or required models are missing
+         FileNotFoundError: If image doesn't exist
+         RuntimeError: If inference fails
+     """
+     image_path = Path(image_path)
+
+     # Validate inputs
+     if not image_path.exists():
+         raise FileNotFoundError(f"Image not found: {image_path}")
+
+     if strategy not in ["yolo", "gdino_sam"]:
+         raise ValueError(f"Invalid strategy: {strategy}. Must be 'yolo' or 'gdino_sam'")
+
+     # Auto-detect device if not specified
+     device = get_device(device=device, verbose=True)
+
+     # Dispatch to appropriate implementation
+     if strategy == "yolo":
+         if yolo_model is None:
+             raise ValueError("yolo_model is required when strategy='yolo'")
+         return _run_yolo_segmentation(
+             model=yolo_model,
+             image_path=image_path,
+             confidence_threshold=confidence_threshold,
+             device=device,
+         )
+
+     elif strategy == "gdino_sam":
+         if gdino_processor is None or gdino_model is None or sam_predictor is None:
+             raise ValueError(
+                 "gdino_processor, gdino_model, and sam_predictor are required when strategy='gdino_sam'"
+             )
+         return _run_gdino_sam_segmentation(
+             gdino_processor=gdino_processor,
+             gdino_model=gdino_model,
+             sam_predictor=sam_predictor,
+             image_path=image_path,
+             confidence_threshold=confidence_threshold,
+             text_prompt=text_prompt,
+             box_threshold=box_threshold,
+             text_threshold=text_threshold,
+             device=device,
+         )
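Review note: the pixel-space box to normalized center-based xywh conversion in the GDINO+SAM path is easy to check in isolation. The sketch below is not part of this commit; `xyxy_to_xywhn` is a hypothetical helper name, and it only mirrors the arithmetic shown in the diff above.

```python
def xyxy_to_xywhn(box, image_width, image_height):
    """Convert a pixel-space (x_min, y_min, x_max, y_max) box to normalized center xywh.

    Hypothetical helper mirroring the arithmetic in the diff; not the project's API.
    """
    x_min, y_min, x_max, y_max = (float(v) for v in box)
    return {
        "x_center": (x_min + x_max) / 2 / image_width,
        "y_center": (y_min + y_max) / 2 / image_height,
        "width": (x_max - x_min) / image_width,
        "height": (y_max - y_min) / image_height,
    }


# A 200x200 px box inside a 400x500 px (width x height) image
bbox = xyxy_to_xywhn((100, 50, 300, 250), image_width=400, image_height=500)
print(bbox)
```

All four values land in [0, 1] as long as the box lies inside the image, which is what downstream consumers of `bbox_xywhn` would presumably expect.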
src/snowleopard_reid/utils.py ADDED
@@ -0,0 +1,59 @@
+ """Utility functions for snow leopard re-identification.
+
+ This module provides common utilities used across the project.
+ """
+
+ import logging
+
+ import torch
+
+ logger = logging.getLogger(__name__)
+
+
+ def get_device(device: str | None = None, verbose: bool = True) -> str:
+     """Get the device to use for computation.
+
+     Auto-detects GPU if available, or uses CPU as fallback.
+     Optionally allows manual override.
+
+     Args:
+         device: Manual device override ('cpu', 'cuda', or None for auto-detect)
+         verbose: Whether to log device information
+
+     Returns:
+         Device string ('cuda' or 'cpu')
+
+     Examples:
+         >>> device = get_device()  # Auto-detect
+         >>> device = get_device('cpu')  # Force CPU
+         >>> device = get_device('cuda')  # Force CUDA (will fail if not available)
+     """
+     if device is None:
+         # Auto-detect
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         if verbose:
+             if device == "cuda":
+                 gpu_name = torch.cuda.get_device_name(0)
+                 gpu_memory = torch.cuda.get_device_properties(0).total_memory / (
+                     1024**3
+                 )
+                 logger.info(f"Using GPU: {gpu_name} ({gpu_memory:.1f} GB)")
+             else:
+                 logger.info("Using CPU (no GPU available)")
+     else:
+         # Manual override
+         device = device.lower()
+         if device not in ["cpu", "cuda"]:
+             raise ValueError(f"Invalid device: {device}. Must be 'cpu' or 'cuda'")
+
+         if device == "cuda" and not torch.cuda.is_available():
+             raise RuntimeError(
+                 "CUDA device requested but CUDA is not available. "
+                 "Install CUDA-enabled PyTorch or use device='cpu'"
+             )
+
+         if verbose:
+             logger.info(f"Using device: {device}")
+
+     return device
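Review note: the selection rules in `get_device` can be exercised without a GPU or a PyTorch install. The sketch below is an assumption-laden mirror of that logic, not code from this commit: `resolve_device` is a hypothetical name, and its `cuda_available` flag stands in for `torch.cuda.is_available()`.

```python
from __future__ import annotations


def resolve_device(requested: str | None, cuda_available: bool) -> str:
    """Mirror of the get_device rules with CUDA availability injected (hypothetical helper)."""
    if requested is None:
        # Auto-detect: prefer the GPU, fall back to CPU
        return "cuda" if cuda_available else "cpu"

    requested = requested.lower()
    if requested not in ("cpu", "cuda"):
        raise ValueError(f"Invalid device: {requested}. Must be 'cpu' or 'cuda'")
    if requested == "cuda" and not cuda_available:
        raise RuntimeError("CUDA device requested but CUDA is not available")
    return requested


print(resolve_device(None, cuda_available=False))
print(resolve_device("CUDA", cuda_available=True))
```

Injecting the availability flag keeps the branching testable on CPU-only machines; the real function additionally logs the GPU name and memory.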
src/snowleopard_reid/visualization.py ADDED
@@ -0,0 +1,215 @@
+ """Visualization utilities for snow leopard re-identification.
+
+ This module provides functions for visualizing keypoints, matches, and other
+ pipeline outputs for debugging and presentation.
+ """
+
+ from pathlib import Path
+
+ import cv2
+ import numpy as np
+ from PIL import Image
+
+
+ def draw_keypoints_overlay(
+     image_path: Path | str,
+     keypoints: np.ndarray,
+     max_keypoints: int = 500,
+     color: str = "blue",
+     ps: int = 10,
+ ) -> Image.Image:
+     """Draw keypoints overlaid on an image.
+
+     Args:
+         image_path: Path to image file
+         keypoints: Keypoints array of shape [N, 2] with (x, y) coordinates
+         max_keypoints: Maximum number of keypoints to display
+         color: Color name ('blue', 'red', 'green', etc.)
+         ps: Point size for keypoints
+
+     Returns:
+         PIL Image with keypoints drawn
+     """
+     # Load image
+     img = cv2.imread(str(image_path))
+     img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+     # Color mapping
+     color_map = {
+         "blue": (0, 0, 255),
+         "red": (255, 0, 0),
+         "green": (0, 255, 0),
+         "yellow": (255, 255, 0),
+         "cyan": (0, 255, 255),
+         "magenta": (255, 0, 255),
+     }
+     color_rgb = color_map.get(color.lower(), (0, 0, 255))
+
+     # Draw keypoints (limit to max_keypoints)
+     n_keypoints = min(len(keypoints), max_keypoints)
+     for i in range(n_keypoints):
+         x, y = keypoints[i]
+         cv2.circle(img_rgb, (int(x), int(y)), ps // 2, color_rgb, -1)
+
+     return Image.fromarray(img_rgb)
+
+
+ def draw_matched_keypoints(
+     query_image_path: Path | str,
+     catalog_image_path: Path | str,
+     query_keypoints: np.ndarray,
+     catalog_keypoints: np.ndarray,
+     match_scores: np.ndarray,
+     max_matches: int = 100,
+ ) -> Image.Image:
+     """Draw matched keypoints side-by-side with connecting lines.
+
+     Args:
+         query_image_path: Path to query image
+         catalog_image_path: Path to catalog image
+         query_keypoints: Query keypoints [N, 2]
+         catalog_keypoints: Catalog keypoints [N, 2]
+         match_scores: Match confidence scores [N]
+         max_matches: Maximum number of matches to display
+
+     Returns:
+         PIL Image with side-by-side images and match lines
+     """
+     # Load images
+     query_img = cv2.imread(str(query_image_path))
+     catalog_img = cv2.imread(str(catalog_image_path))
+
+     # Convert to RGB
+     query_rgb = cv2.cvtColor(query_img, cv2.COLOR_BGR2RGB)
+     catalog_rgb = cv2.cvtColor(catalog_img, cv2.COLOR_BGR2RGB)
+
+     # Resize images to same height for side-by-side display
+     max_height = 800
+     query_h, query_w = query_rgb.shape[:2]
+     catalog_h, catalog_w = catalog_rgb.shape[:2]
+
+     # Scale both images to max_height if either exceeds it; otherwise keep original sizes
+     if query_h > max_height or catalog_h > max_height:
+         query_scale = max_height / query_h
+         catalog_scale = max_height / catalog_h
+     else:
+         query_scale = 1.0
+         catalog_scale = 1.0
+
+     # Resize images
+     new_query_h = int(query_h * query_scale)
+     new_query_w = int(query_w * query_scale)
+     new_catalog_h = int(catalog_h * catalog_scale)
+     new_catalog_w = int(catalog_w * catalog_scale)
+
+     query_resized = cv2.resize(query_rgb, (new_query_w, new_query_h))
+     catalog_resized = cv2.resize(catalog_rgb, (new_catalog_w, new_catalog_h))
+
+     # Scale keypoints
+     query_kpts_scaled = query_keypoints * query_scale
+     catalog_kpts_scaled = catalog_keypoints * catalog_scale
+
+     # Create side-by-side canvas
+     combined_h = max(new_query_h, new_catalog_h)
+     combined_w = new_query_w + new_catalog_w
+     canvas = np.zeros((combined_h, combined_w, 3), dtype=np.uint8)
+
+     # Place images on canvas
+     canvas[:new_query_h, :new_query_w] = query_resized
+     canvas[:new_catalog_h, new_query_w : new_query_w + new_catalog_w] = catalog_resized
+
+     # Offset catalog keypoints to account for horizontal placement
+     catalog_kpts_offset = catalog_kpts_scaled.copy()
+     catalog_kpts_offset[:, 0] += new_query_w
+
+     # Draw matches (limit to max_matches)
+     n_matches = min(len(query_kpts_scaled), max_matches)
+
+     # Sort by match scores (highest confidence first)
+     if len(match_scores) > 0:
+         sorted_indices = np.argsort(match_scores)[::-1][:n_matches]
+     else:
+         sorted_indices = np.arange(n_matches)
+
+     # Draw lines and keypoints
+     for idx in sorted_indices:
+         query_pt = tuple(query_kpts_scaled[idx].astype(int))
+         catalog_pt = tuple(catalog_kpts_offset[idx].astype(int))
+
+         # Color based on match score (green = high, yellow = medium, red = low)
+         score = match_scores[idx] if len(match_scores) > 0 else 0.5
+         if score > 0.8:
+             color = (0, 255, 0)  # Green
+         elif score > 0.5:
+             color = (255, 255, 0)  # Yellow
+         else:
+             color = (255, 0, 0)  # Red
+
+         # Draw line
+         cv2.line(canvas, query_pt, catalog_pt, color, 1)
+
+         # Draw keypoints (query in red, catalog in blue)
+         cv2.circle(canvas, query_pt, 5, (255, 0, 0), -1)
+         cv2.circle(canvas, catalog_pt, 5, (0, 0, 255), -1)
+
+     return Image.fromarray(canvas)
+
+
+ def draw_side_by_side_comparison(
+     query_image_path: Path | str,
+     catalog_image_path: Path | str,
+     max_height: int = 800,
+ ) -> Image.Image:
+     """Draw query and catalog images side-by-side without keypoints or annotations.
+
+     This provides a clean visual comparison of the two images without the visual
+     clutter of feature matching overlays. Useful for assessing overall visual
+     similarity and spotting patterns like spots, scars, or markings.
+
+     Args:
+         query_image_path: Path to query image
+         catalog_image_path: Path to catalog/reference image
+         max_height: Maximum height for resizing (default: 800)
+
+     Returns:
+         PIL Image with side-by-side images (no annotations)
+     """
+     # Load images
+     query_img = cv2.imread(str(query_image_path))
+     catalog_img = cv2.imread(str(catalog_image_path))
+
+     # Convert to RGB
+     query_rgb = cv2.cvtColor(query_img, cv2.COLOR_BGR2RGB)
+     catalog_rgb = cv2.cvtColor(catalog_img, cv2.COLOR_BGR2RGB)
+
+     # Resize images to same height for side-by-side display
+     query_h, query_w = query_rgb.shape[:2]
+     catalog_h, catalog_w = catalog_rgb.shape[:2]
+
+     # Calculate scaling factors
+     if query_h > max_height or catalog_h > max_height:
+         query_scale = max_height / query_h
+         catalog_scale = max_height / catalog_h
+     else:
+         query_scale = 1.0
+         catalog_scale = 1.0
+
+     # Resize images
+     new_query_h = int(query_h * query_scale)
+     new_query_w = int(query_w * query_scale)
+     new_catalog_h = int(catalog_h * catalog_scale)
+     new_catalog_w = int(catalog_w * catalog_scale)
+
+     query_resized = cv2.resize(query_rgb, (new_query_w, new_query_h))
+     catalog_resized = cv2.resize(catalog_rgb, (new_catalog_w, new_catalog_h))
+
+     # Create side-by-side canvas
+     combined_h = max(new_query_h, new_catalog_h)
+     combined_w = new_query_w + new_catalog_w
+     canvas = np.zeros((combined_h, combined_w, 3), dtype=np.uint8)
+
+     # Place images on canvas (no keypoints or lines)
+     canvas[:new_query_h, :new_query_w] = query_resized
+     canvas[:new_catalog_h, new_query_w : new_query_w + new_catalog_w] = catalog_resized
+
+     return Image.fromarray(canvas)
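Review note: both side-by-side functions share the same scaling rule, and its behavior is easier to see on plain numbers. The sketch below is a hypothetical pure-Python helper (`side_by_side_layout` is not in this commit) that reproduces only the size arithmetic from the diff, with no OpenCV calls.

```python
def side_by_side_layout(query_hw, catalog_hw, max_height=800):
    """Return (query size, catalog size, canvas size) as (height, width) tuples.

    Hypothetical helper reproducing the scaling arithmetic from the diff.
    """
    query_h, query_w = query_hw
    catalog_h, catalog_w = catalog_hw

    # If either image is taller than max_height, both are scaled to max_height
    if query_h > max_height or catalog_h > max_height:
        query_scale = max_height / query_h
        catalog_scale = max_height / catalog_h
    else:
        query_scale = catalog_scale = 1.0

    new_query = (int(query_h * query_scale), int(query_w * query_scale))
    new_catalog = (int(catalog_h * catalog_scale), int(catalog_w * catalog_scale))
    canvas = (max(new_query[0], new_catalog[0]), new_query[1] + new_catalog[1])
    return new_query, new_catalog, canvas


large = side_by_side_layout((1600, 1200), (400, 600))
small = side_by_side_layout((400, 300), (200, 100))
print(large, small)
```

One consequence worth noting: when only one image exceeds `max_height`, the smaller image is upscaled to match it, whereas two small images keep their original (possibly different) heights and the shorter one is padded with black rows on the canvas.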