feat: initial gradio app
Browse files- .gitattributes +9 -0
- .gitignore +60 -0
- Makefile +31 -0
- README.md +48 -5
- app.py +1608 -0
- data/cache.tar.gz +3 -0
- data/catalog.tar.gz +3 -0
- data/examples/07070305 Agim.JPG +3 -0
- data/examples/08190121 Karindas.JPG +3 -0
- data/examples/08190742 Ayima.jpg +3 -0
- data/examples/09150237 AIKA.JPG +3 -0
- data/examples/IMG_7189 Ayima.JPG +3 -0
- pyproject.toml +32 -0
- requirements.txt +24 -0
- scripts/create_archives.py +154 -0
- scripts/precompute_cache.py +515 -0
- src/snowleopard_reid/__init__.py +27 -0
- src/snowleopard_reid/cache.py +421 -0
- src/snowleopard_reid/catalog/__init__.py +25 -0
- src/snowleopard_reid/catalog/loader.py +379 -0
- src/snowleopard_reid/data_setup.py +102 -0
- src/snowleopard_reid/features/__init__.py +27 -0
- src/snowleopard_reid/features/extraction.py +388 -0
- src/snowleopard_reid/images/__init__.py +13 -0
- src/snowleopard_reid/images/processing.py +93 -0
- src/snowleopard_reid/masks/__init__.py +13 -0
- src/snowleopard_reid/masks/processing.py +99 -0
- src/snowleopard_reid/pipeline/__init__.py +23 -0
- src/snowleopard_reid/pipeline/stages/__init__.py +20 -0
- src/snowleopard_reid/pipeline/stages/feature_extraction.py +134 -0
- src/snowleopard_reid/pipeline/stages/mask_selection.py +153 -0
- src/snowleopard_reid/pipeline/stages/matching.py +593 -0
- src/snowleopard_reid/pipeline/stages/preprocess.py +142 -0
- src/snowleopard_reid/pipeline/stages/segmentation.py +453 -0
- src/snowleopard_reid/utils.py +59 -0
- src/snowleopard_reid/visualization.py +215 -0
.gitattributes
CHANGED
|
@@ -1,3 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
# Images (catalog and examples)
|
| 2 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.JPG filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.JPEG filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.PNG filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
|
| 9 |
+
# Archives and models
|
| 10 |
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 11 |
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 12 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
build/
|
| 8 |
+
develop-eggs/
|
| 9 |
+
dist/
|
| 10 |
+
downloads/
|
| 11 |
+
eggs/
|
| 12 |
+
.eggs/
|
| 13 |
+
lib/
|
| 14 |
+
lib64/
|
| 15 |
+
parts/
|
| 16 |
+
sdist/
|
| 17 |
+
var/
|
| 18 |
+
wheels/
|
| 19 |
+
*.egg-info/
|
| 20 |
+
.installed.cfg
|
| 21 |
+
*.egg
|
| 22 |
+
|
| 23 |
+
# Virtual environments
|
| 24 |
+
.venv/
|
| 25 |
+
venv/
|
| 26 |
+
ENV/
|
| 27 |
+
env/
|
| 28 |
+
|
| 29 |
+
# uv
|
| 30 |
+
uv.lock
|
| 31 |
+
|
| 32 |
+
# IDE
|
| 33 |
+
.idea/
|
| 34 |
+
.vscode/
|
| 35 |
+
*.swp
|
| 36 |
+
*.swo
|
| 37 |
+
*~
|
| 38 |
+
|
| 39 |
+
# Models (downloaded at runtime)
|
| 40 |
+
data/models/
|
| 41 |
+
|
| 42 |
+
# Extracted data (regenerated from archives on first run)
|
| 43 |
+
data/catalog/
|
| 44 |
+
cached_results/
|
| 45 |
+
|
| 46 |
+
# Keep archives tracked (these are the source of truth)
|
| 47 |
+
!data/catalog.tar.gz
|
| 48 |
+
!data/cache.tar.gz
|
| 49 |
+
|
| 50 |
+
# Temp files
|
| 51 |
+
*.tmp
|
| 52 |
+
*.temp
|
| 53 |
+
.DS_Store
|
| 54 |
+
Thumbs.db
|
| 55 |
+
|
| 56 |
+
# Jupyter
|
| 57 |
+
.ipynb_checkpoints/
|
| 58 |
+
|
| 59 |
+
# Logs
|
| 60 |
+
*.log
|
Makefile
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.PHONY: help install run clean create-archives extract-data precompute-cache archive-info
|
| 2 |
+
|
| 3 |
+
help: ## Show this help message
|
| 4 |
+
@echo "Usage: make [target]"
|
| 5 |
+
@echo ""
|
| 6 |
+
@echo "Available targets:"
|
| 7 |
+
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-15s\033[0m %s\n", $$1, $$2}'
|
| 8 |
+
|
| 9 |
+
install: ## Install dependencies with uv
|
| 10 |
+
uv sync
|
| 11 |
+
|
| 12 |
+
run: ## Run the Gradio app locally
|
| 13 |
+
uv run python app.py
|
| 14 |
+
|
| 15 |
+
clean: ## Clean up temporary files and caches
|
| 16 |
+
rm -rf .venv uv.lock
|
| 17 |
+
find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
| 18 |
+
find . -type f -name "*.pyc" -delete
|
| 19 |
+
find . -type f -name "*.pyo" -delete
|
| 20 |
+
|
| 21 |
+
create-archives: ## Create compressed archives from catalog and cache
|
| 22 |
+
uv run python scripts/create_archives.py
|
| 23 |
+
|
| 24 |
+
extract-data: ## Extract archives (done automatically on first run)
|
| 25 |
+
uv run python -c "from snowleopard_reid.data_setup import ensure_data_extracted; ensure_data_extracted()"
|
| 26 |
+
|
| 27 |
+
precompute-cache: ## Run pipeline on all examples to generate cache
|
| 28 |
+
uv run python scripts/precompute_cache.py
|
| 29 |
+
|
| 30 |
+
archive-info: ## Show info about archives and directories
|
| 31 |
+
uv run python scripts/create_archives.py --info
|
README.md
CHANGED
|
@@ -1,13 +1,56 @@
|
|
| 1 |
---
|
| 2 |
-
title: Snowleopard
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: indigo
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
-
short_description:
|
| 11 |
---
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Snowleopard reID
|
| 3 |
+
emoji: 🐆
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: indigo
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 5.49.1
|
| 8 |
+
python_version: 3.11
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
| 11 |
+
short_description: Snow Leopard reID using AI
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# Snow Leopard Re-Identification
|
| 15 |
+
|
| 16 |
+
AI-powered snow leopard re-identification system using computer vision and deep learning.
|
| 17 |
+
|
| 18 |
+
## Features
|
| 19 |
+
|
| 20 |
+
- **Single Image Matching**: Upload an image to identify individual snow leopards against a catalog
|
| 21 |
+
- **Batch Processing**: Process multiple images at once with filtering options
|
| 22 |
+
- **Multiple Detection Methods**: YOLO or Grounding DINO + SAM HQ for segmentation
|
| 23 |
+
- **Multiple Matching Algorithms**: SIFT, SuperPoint, DISK, or ALIKED feature extractors
|
| 24 |
+
|
| 25 |
+
## Local Development
|
| 26 |
+
|
| 27 |
+
### Prerequisites
|
| 28 |
+
|
| 29 |
+
- Python 3.11+
|
| 30 |
+
- [uv](https://github.com/astral-sh/uv) package manager
|
| 31 |
+
|
| 32 |
+
### Setup
|
| 33 |
+
|
| 34 |
+
```bash
|
| 35 |
+
# Install dependencies
|
| 36 |
+
make install
|
| 37 |
+
|
| 38 |
+
# Run the app locally
|
| 39 |
+
make run
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
The app will be available at `http://localhost:7860`.
|
| 43 |
+
|
| 44 |
+
## Data
|
| 45 |
+
|
| 46 |
+
- **Catalog**: Pre-computed features for known snow leopards (stored with Git LFS)
|
| 47 |
+
- **SAM HQ Model**: Downloaded automatically at runtime from HuggingFace Hub
|
| 48 |
+
|
| 49 |
+
## Tech Stack
|
| 50 |
+
|
| 51 |
+
- Gradio for the web interface
|
| 52 |
+
- PyTorch for deep learning
|
| 53 |
+
- Grounding DINO for zero-shot object detection
|
| 54 |
+
- SAM HQ (Segment Anything Model High Quality) for segmentation
|
| 55 |
+
- LightGlue for feature matching
|
| 56 |
+
- Wasserstein distance for match scoring
|
app.py
ADDED
|
@@ -0,0 +1,1608 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Gradio web application for snow leopard identification and catalog exploration.
|
| 2 |
+
|
| 3 |
+
This interactive web interface provides an easy-to-use frontend for the snow
|
| 4 |
+
leopard identification system. Users can upload images, view matches against the catalog,
|
| 5 |
+
and explore reference leopards through a browser-based UI powered by Gradio.
|
| 6 |
+
|
| 7 |
+
Features:
|
| 8 |
+
- Upload snow leopard images or select from examples
|
| 9 |
+
- Run full identification pipeline with GDINO+SAM segmentation
|
| 10 |
+
- View top-K matches with Wasserstein distance scores
|
| 11 |
+
- Explore complete leopard catalog with thumbnails
|
| 12 |
+
- Visualize matched keypoints between query and catalog images
|
| 13 |
+
|
| 14 |
+
Usage:
|
| 15 |
+
# Local testing with uv:
|
| 16 |
+
uv sync
|
| 17 |
+
uv run python app.py
|
| 18 |
+
|
| 19 |
+
# Deployed on Hugging Face Spaces
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import sys
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
|
| 25 |
+
# Add src to path for imports BEFORE importing snowleopard_reid
|
| 26 |
+
SPACE_ROOT = Path(__file__).parent
|
| 27 |
+
sys.path.insert(0, str(SPACE_ROOT / "src"))
|
| 28 |
+
|
| 29 |
+
import logging
|
| 30 |
+
import shutil
|
| 31 |
+
import tempfile
|
| 32 |
+
from dataclasses import dataclass
|
| 33 |
+
|
| 34 |
+
import cv2
|
| 35 |
+
import gradio as gr
|
| 36 |
+
import numpy as np
|
| 37 |
+
import torch
|
| 38 |
+
import yaml
|
| 39 |
+
from huggingface_hub import hf_hub_download
|
| 40 |
+
from PIL import Image
|
| 41 |
+
|
| 42 |
+
from snowleopard_reid.catalog import (
|
| 43 |
+
get_available_body_parts,
|
| 44 |
+
get_available_locations,
|
| 45 |
+
get_catalog_metadata_for_id,
|
| 46 |
+
load_catalog_index,
|
| 47 |
+
load_leopard_metadata,
|
| 48 |
+
)
|
| 49 |
+
from snowleopard_reid.pipeline.stages import (
|
| 50 |
+
run_feature_extraction_stage,
|
| 51 |
+
run_matching_stage,
|
| 52 |
+
run_preprocess_stage,
|
| 53 |
+
run_segmentation_stage,
|
| 54 |
+
select_best_mask,
|
| 55 |
+
)
|
| 56 |
+
from snowleopard_reid.pipeline.stages.segmentation import (
|
| 57 |
+
load_gdino_model,
|
| 58 |
+
load_sam_predictor,
|
| 59 |
+
)
|
| 60 |
+
from snowleopard_reid.visualization import (
|
| 61 |
+
draw_keypoints_overlay,
|
| 62 |
+
draw_matched_keypoints,
|
| 63 |
+
draw_side_by_side_comparison,
|
| 64 |
+
)
|
| 65 |
+
from snowleopard_reid.cache import (
|
| 66 |
+
filter_cached_matches,
|
| 67 |
+
generate_visualizations_from_npz,
|
| 68 |
+
is_cached,
|
| 69 |
+
load_cached_results,
|
| 70 |
+
)
|
| 71 |
+
from snowleopard_reid.data_setup import ensure_data_extracted
|
| 72 |
+
|
| 73 |
+
# Configure logging
|
| 74 |
+
logging.basicConfig(
|
| 75 |
+
level=logging.INFO,
|
| 76 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
| 77 |
+
)
|
| 78 |
+
logger = logging.getLogger(__name__)
|
| 79 |
+
|
| 80 |
+
# Configuration (hardcoded for HF Spaces / local dev)
CATALOG_ROOT = SPACE_ROOT / "data" / "catalog"  # root of the leopard catalog data
SAM_CHECKPOINT_DIR = SPACE_ROOT / "data" / "models"  # where model weights are stored/downloaded
SAM_CHECKPOINT_NAME = "sam_hq_vit_l.pth"  # SAM HQ ViT-L checkpoint filename on the hub
EXAMPLES_DIR = SPACE_ROOT / "data" / "examples"  # bundled example images
GDINO_MODEL_ID = "IDEA-Research/grounding-dino-base"  # Hugging Face model id for Grounding DINO
TEXT_PROMPT = "a snow leopard."  # zero-shot detection prompt passed to Grounding DINO
TOP_K_DEFAULT = 5  # default number of top matches to show
SAM_MODEL_TYPE = "vit_l"  # SAM backbone variant; must match SAM_CHECKPOINT_NAME
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@dataclass
class AppConfig:
    """Configuration for the Snow Leopard ID UI application."""

    # Optional model path; NOTE(review): not referenced in this module's
    # visible code — possibly unused here, confirm against other callers.
    model_path: Path | None
    # Root directory of the leopard catalog (contains catalog_index.yaml).
    catalog_root: Path
    # Directory containing bundled example images.
    examples_dir: Path
    # Default number of top matches to display.
    top_k: int
    # Server port; presumably passed to Gradio's launch — confirm at call site.
    port: int
    # Whether to create a public share link; presumably forwarded to Gradio.
    share: bool
    # GDINO+SAM parameters
    # Local path to the SAM HQ checkpoint file.
    sam_checkpoint_path: Path
    # SAM backbone variant (e.g. "vit_l"); must match the checkpoint.
    sam_model_type: str
    # Hugging Face model id for Grounding DINO.
    gdino_model_id: str
    # Text prompt used for zero-shot snow leopard detection.
    text_prompt: str
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def ensure_sam_model() -> Path:
    """Ensure the SAM HQ checkpoint is present locally, downloading if needed.

    The ~1.6 GB checkpoint is fetched once from the ``lkeab/hq-sam`` repo on
    the Hugging Face Hub into ``SAM_CHECKPOINT_DIR``; subsequent calls are a
    no-op.

    Returns:
        Path to the SAM HQ checkpoint file.
    """
    checkpoint = SAM_CHECKPOINT_DIR / SAM_CHECKPOINT_NAME
    if checkpoint.exists():
        return checkpoint

    logger.info("Downloading SAM HQ model (1.6GB)...")
    SAM_CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    hf_hub_download(
        repo_id="lkeab/hq-sam",
        filename=SAM_CHECKPOINT_NAME,
        local_dir=SAM_CHECKPOINT_DIR,
    )
    logger.info("SAM HQ model downloaded successfully")
    return checkpoint
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def get_available_extractors(catalog_root: Path) -> list[str]:
|
| 128 |
+
"""Get list of available feature extractors from catalog.
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
catalog_root: Root directory of the leopard catalog
|
| 132 |
+
|
| 133 |
+
Returns:
|
| 134 |
+
List of available extractor names (e.g., ['sift', 'superpoint'])
|
| 135 |
+
"""
|
| 136 |
+
try:
|
| 137 |
+
catalog_index = load_catalog_index(catalog_root)
|
| 138 |
+
extractors = list(catalog_index.get("feature_extractors", {}).keys())
|
| 139 |
+
if not extractors:
|
| 140 |
+
logger.warning(f"No extractors found in catalog at {catalog_root}")
|
| 141 |
+
return ["sift"] # Default fallback
|
| 142 |
+
return extractors
|
| 143 |
+
except Exception as e:
|
| 144 |
+
logger.error(f"Failed to load catalog index: {e}")
|
| 145 |
+
return ["sift"] # Default fallback
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# Global state for models and catalog (loaded at startup).
# Populated by initialize_models() with "gdino_processor", "gdino_model",
# "sam_predictor", "device", "catalog_root", and "text_prompt"; request
# handlers additionally stash "current_*" entries (visualizations, enriched
# matches, filters, temp dir) for later match-selection callbacks.
LOADED_MODELS = {}
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def load_catalog_data(config: AppConfig):
    """Load catalog index and individual leopard metadata.

    Args:
        config: Application configuration containing catalog_root

    Returns:
        Tuple of (catalog_index, individuals_data) where individuals_data is
        a list of per-leopard metadata dicts in catalog-index order.
    """
    # Parse the top-level catalog index.
    index_file = config.catalog_root / "catalog_index.yaml"
    with open(index_file) as fh:
        catalog_index = yaml.safe_load(fh)

    # Parse each individual's metadata file referenced by the index.
    individuals_data = []
    for entry in catalog_index["individuals"]:
        with open(config.catalog_root / entry["metadata_path"]) as fh:
            individuals_data.append(yaml.safe_load(fh))

    return catalog_index, individuals_data
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def initialize_models(config: AppConfig):
    """Load models at startup for faster inference.

    Populates the module-level LOADED_MODELS dict with the Grounding DINO
    processor/model, the SAM HQ predictor, and runtime info (device, catalog
    root, text prompt) consumed by the identification callbacks.

    Args:
        config: Application configuration containing model paths
    """
    logger.info("Initializing models...")

    # All models go on the GPU when one is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    if device == "cuda":
        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        logger.info(f"GPU: {gpu_name} ({gpu_memory:.1f} GB)")

    # Grounding DINO: zero-shot detector driven by the text prompt.
    logger.info(f"Loading Grounding DINO model: {config.gdino_model_id}")
    processor, model = load_gdino_model(
        model_id=config.gdino_model_id,
        device=device,
    )
    LOADED_MODELS["gdino_processor"] = processor
    LOADED_MODELS["gdino_model"] = model
    logger.info("Grounding DINO model loaded successfully")

    # SAM HQ: segmentation model seeded by the detector's boxes.
    logger.info(
        f"Loading SAM HQ model from {config.sam_checkpoint_path} (type: {config.sam_model_type})"
    )
    LOADED_MODELS["sam_predictor"] = load_sam_predictor(
        checkpoint_path=config.sam_checkpoint_path,
        model_type=config.sam_model_type,
        device=device,
    )
    logger.info("SAM HQ model loaded successfully")

    # Runtime info the Gradio callbacks read back later.
    LOADED_MODELS["device"] = device
    LOADED_MODELS["catalog_root"] = config.catalog_root
    LOADED_MODELS["text_prompt"] = config.text_prompt

    logger.info("Models initialized successfully")
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def _load_from_cache(
    example_path: str,
    extractor: str,
    config: "AppConfig",
    filter_locations: list[str] | None = None,
    filter_body_parts: list[str] | None = None,
    top_k: int = 5,
):
    """Load cached pipeline results with optional filtering and return UI component updates.

    Supports the v2.0 cache format which stores ALL matches with location/body_part
    metadata, enabling client-side filtering without re-running the pipeline.

    Args:
        example_path: Path to the example image
        extractor: Feature extractor name
        config: Application configuration
        filter_locations: Optional list of locations to filter by
        filter_body_parts: Optional list of body parts to filter by
        top_k: Number of top matches to return after filtering

    Returns:
        Tuple of 23 UI components matching run_identification output
    """
    # Load cached results
    cached = load_cached_results(example_path, extractor)
    predictions = cached["predictions"]

    # Support both v1.0 ("matches") and v2.0 ("all_matches") cache formats
    if "all_matches" in predictions:
        all_matches = predictions["all_matches"]
    else:
        # Fallback for v1.0 cache format (no filtering support)
        all_matches = predictions.get("matches", [])

    # Filter and re-rank matches
    matches = filter_cached_matches(
        all_matches=all_matches,
        filter_locations=filter_locations,
        filter_body_parts=filter_body_parts,
        top_k=top_k,
    )

    if not matches:
        # No matches after filtering - return empty results.
        # Tuple shape (23 items) must stay in sync with run_identification's
        # outputs: message, 3 images, table, 2 viz, header, 5 indicators,
        # 5 empty-message toggles, 5 galleries.
        return (
            "No matches found with the selected filters",
            cached["segmentation_image"],
            cached["cropped_image"],
            cached["keypoints_image"],
            [],
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value=""),
            gr.update(value=""),
            gr.update(value=""),
            gr.update(value=""),
            gr.update(value=""),
            gr.update(value=""),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(value=[]),
            gr.update(value=[]),
            gr.update(value=[]),
            gr.update(value=[]),
            gr.update(value=[]),
        )

    # Generate visualizations on-demand from NPZ data
    logger.info(f"Generating visualizations for {len(matches)} filtered matches...")
    match_visualizations, clean_comparison_visualizations = (
        generate_visualizations_from_npz(
            pairwise_dir=cached["pairwise_dir"],
            matches=matches,
            cropped_image_path=cached["pairwise_dir"].parent / "cropped.png",
        )
    )

    # Store in global state for match selection
    LOADED_MODELS["current_match_visualizations"] = match_visualizations
    LOADED_MODELS["current_clean_comparison_visualizations"] = (
        clean_comparison_visualizations
    )
    LOADED_MODELS["current_enriched_matches"] = matches
    LOADED_MODELS["current_filter_body_parts"] = filter_body_parts
    LOADED_MODELS["current_temp_dir"] = None  # No temp dir for cached results

    # Top match info for result text
    top_match = matches[0]
    top_leopard_name = top_match["leopard_name"]
    top_wasserstein = top_match["wasserstein"]

    # Determine confidence level. Higher score = stronger match here;
    # thresholds appear empirically chosen — confirm against matching stage.
    if top_wasserstein >= 0.12:
        confidence_indicator = "🔵"  # Excellent
    elif top_wasserstein >= 0.07:
        confidence_indicator = "🟢"  # Good
    elif top_wasserstein >= 0.04:
        confidence_indicator = "🟡"  # Fair
    else:
        confidence_indicator = "🔴"  # Uncertain

    result_text = f"## {confidence_indicator} {top_leopard_name.title()}"

    # Build dataset for top-K matches table
    dataset_samples = []
    for match in matches:
        rank = match["rank"]
        leopard_name = match["leopard_name"]
        wasserstein = match["wasserstein"]

        # Use location from cache (v2.0) or extract from path
        location = match.get("location", "unknown")
        if location == "unknown":
            catalog_id = match["catalog_id"]
            catalog_metadata = get_catalog_metadata_for_id(
                config.catalog_root, catalog_id
            )
            if catalog_metadata:
                # Catalog image paths look like .../database/<location>/...;
                # take the path component right after "database".
                img_path_parts = Path(catalog_metadata["image_path"]).parts
                try:
                    db_idx = img_path_parts.index("database")
                    if db_idx + 1 < len(img_path_parts):
                        location = img_path_parts[db_idx + 1]
                except ValueError:
                    pass  # no "database" segment; keep "unknown"

        # Confidence indicator (same thresholds as the top-match banner)
        if wasserstein >= 0.12:
            indicator = "🔵"
        elif wasserstein >= 0.07:
            indicator = "🟢"
        elif wasserstein >= 0.04:
            indicator = "🟡"
        else:
            indicator = "🔴"

        dataset_samples.append(
            [
                rank,
                indicator,
                leopard_name.title(),
                location.replace("_", " ").title(),
                f"{wasserstein:.4f}",
            ]
        )

    # Load rank 1 details
    # NOTE(review): presumably reads the LOADED_MODELS "current_*" entries
    # stored above — definition not visible in this chunk; confirm.
    rank1_details = load_match_details_for_rank(rank=1)

    # Return all 23 outputs
    return (
        result_text,  # 1. Top match result text
        cached["segmentation_image"],  # 2. Segmentation overlay
        cached["cropped_image"],  # 3. Cropped leopard
        cached["keypoints_image"],  # 4. Extracted keypoints
        dataset_samples,  # 5. Matches table data
        *rank1_details,  # 6-23. visualizations, header, indicators, galleries
    )
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def _confidence_emoji(wasserstein: float) -> str:
    """Map a Wasserstein similarity score to a confidence emoji.

    Higher Wasserstein = better match. Thresholds (app convention):
    🔵 excellent (>= 0.12), 🟢 good (>= 0.07), 🟡 fair (>= 0.04),
    🔴 uncertain (below 0.04).
    """
    if wasserstein >= 0.12:
        return "🔵"
    if wasserstein >= 0.07:
        return "🟢"
    if wasserstein >= 0.04:
        return "🟡"
    return "🔴"


def _location_for_match(catalog_root, catalog_id) -> str:
    """Derive the geographic location for a catalog entry from its image path.

    Catalog image paths follow database/{location}/{individual}/...; returns
    "unknown" when metadata is missing or the "database" segment is absent.
    """
    catalog_metadata = get_catalog_metadata_for_id(catalog_root, catalog_id)
    if catalog_metadata:
        parts = Path(catalog_metadata["image_path"]).parts
        if len(parts) >= 3 and "database" in parts:
            db_idx = parts.index("database")
            if db_idx + 1 < len(parts):
                return parts[db_idx + 1]
    return "unknown"


def _empty_pipeline_outputs(result_text: str, cropped_image=None) -> tuple:
    """Build the 23-output tuple returned when the pipeline has no results.

    Layout: 5 pipeline outputs (result text, segmentation viz, cropped image,
    keypoints viz, matches table) followed by 18 cleared rank-1 detail
    updates (2 visualizations, header, 5 body-part indicators, 5 empty
    messages, 5 galleries).
    """
    return (
        result_text,  # 1. result_text
        None,  # 2. seg_viz
        cropped_image,  # 3. cropped_image
        None,  # 4. extracted_kpts_viz
        [],  # 5. dataset_samples
        gr.update(value=None),  # 6. matched_kpts_viz
        gr.update(value=None),  # 7. clean_comparison_viz
        gr.update(value=""),  # 8. header
        gr.update(value=""),  # 9. head indicator
        gr.update(value=""),  # 10. left_flank indicator
        gr.update(value=""),  # 11. right_flank indicator
        gr.update(value=""),  # 12. tail indicator
        gr.update(value=""),  # 13. misc indicator
        gr.update(visible=False),  # 14. head empty message
        gr.update(visible=False),  # 15. left_flank empty message
        gr.update(visible=False),  # 16. right_flank empty message
        gr.update(visible=False),  # 17. tail empty message
        gr.update(visible=False),  # 18. misc empty message
        gr.update(value=[]),  # 19. head gallery
        gr.update(value=[]),  # 20. left_flank gallery
        gr.update(value=[]),  # 21. right_flank gallery
        gr.update(value=[]),  # 22. tail gallery
        gr.update(value=[]),  # 23. misc gallery
    )


def run_identification(
    image,
    extractor: str,
    top_k: int,
    selected_locations: list[str],
    selected_body_parts: list[str],
    example_path: str | None,
    config: AppConfig,
):
    """Run snow leopard identification pipeline on uploaded image.

    Args:
        image: PIL Image from Gradio upload
        extractor: Feature extractor to use ('sift', 'superpoint', 'disk', 'aliked')
        top_k: Number of top matches to return
        selected_locations: List of selected locations (includes "all" for no filtering)
        selected_body_parts: List of selected body parts (includes "all" for no filtering)
        example_path: Path to example image if selected from examples (for cache lookup)
        config: Application configuration

    Returns:
        Tuple of 23 UI component updates (see _empty_pipeline_outputs for layout).
    """
    if image is None:
        return _empty_pipeline_outputs("Please upload an image first")

    # Convert filter selections to None if "all" is selected (None = no filtering)
    filter_locations = (
        None
        if not selected_locations or "all" in selected_locations
        else selected_locations
    )
    filter_body_parts = (
        None
        if not selected_body_parts or "all" in selected_body_parts
        else selected_body_parts
    )

    # Check cache for example images (v2.0 cache supports filtering)
    if example_path and is_cached(example_path, extractor):
        logger.info(f"Cache hit for {example_path} with {extractor}")
        if filter_locations or filter_body_parts:
            logger.info(f"  Applying filters: locations={filter_locations}, body_parts={filter_body_parts}")
        try:
            return _load_from_cache(
                example_path,
                extractor,
                config,
                filter_locations=filter_locations,
                filter_body_parts=filter_body_parts,
                top_k=int(top_k),
            )
        except Exception as e:
            logger.warning(f"Cache load failed, running pipeline: {e}")
            # Fall through to run full pipeline

    # Log applied filters
    if filter_locations or filter_body_parts:
        filter_desc = []
        if filter_locations:
            filter_desc.append(f"locations: {', '.join(filter_locations)}")
        if filter_body_parts:
            filter_desc.append(f"body parts: {', '.join(filter_body_parts)}")
        logger.info(f"Applied filters - {' | '.join(filter_desc)}")
    else:
        logger.info("No filters applied - matching against entire catalog")

    try:
        # Create temporary directory for this query
        temp_dir = Path(tempfile.mkdtemp(prefix="snowleopard_id_"))
        temp_image_path = temp_dir / "query.jpg"

        # Save uploaded image (logged for debugging upload-format issues)
        logger.info(f"Image type: {type(image)}")
        logger.info(f"Image mode: {image.mode if hasattr(image, 'mode') else 'N/A'}")
        logger.info(f"Image size: {image.size if hasattr(image, 'size') else 'N/A'}")
        image.save(temp_image_path, quality=95)

        # Verify saved image
        saved_size = temp_image_path.stat().st_size
        logger.info(f"Saved image size: {saved_size / 1024 / 1024:.2f} MB")
        logger.info(f"Processing query image: {temp_image_path}")

        device = LOADED_MODELS.get("device", "cpu")

        # Step 1: Run GDINO+SAM segmentation using pre-loaded models
        logger.info("Running GDINO+SAM segmentation...")
        seg_stage = run_segmentation_stage(
            image_path=temp_image_path,
            strategy="gdino_sam",
            confidence_threshold=0.2,
            device=device,
            gdino_processor=LOADED_MODELS.get("gdino_processor"),
            gdino_model=LOADED_MODELS.get("gdino_model"),
            sam_predictor=LOADED_MODELS.get("sam_predictor"),
            text_prompt=LOADED_MODELS.get("text_prompt", "a snow leopard."),
            box_threshold=0.30,
            text_threshold=0.20,
        )

        predictions = seg_stage["data"]["predictions"]
        logger.info(f"Number of predictions: {len(predictions)}")

        if not predictions:
            logger.warning("No predictions found from segmentation")
            logger.warning(f"Full segmentation stage: {seg_stage}")
            return _empty_pipeline_outputs("No snow leopards detected in image")

        # Step 2: Select best mask
        logger.info("Selecting best mask...")
        _, selected_pred = select_best_mask(
            predictions,
            strategy="confidence_area",
        )

        # Step 3: Preprocess (crop and mask)
        logger.info("Preprocessing query image...")
        prep_stage = run_preprocess_stage(
            image_path=temp_image_path,
            mask=selected_pred["mask"],
            padding=5,
        )
        cropped_image_pil = prep_stage["data"]["cropped_image"]

        # Save cropped image for visualization later
        cropped_path = temp_dir / "cropped.jpg"
        cropped_image_pil.save(cropped_path)

        # Step 4: Extract features
        logger.info(f"Extracting features using {extractor.upper()}...")
        feat_stage = run_feature_extraction_stage(
            image=cropped_image_pil,
            extractor=extractor,
            max_keypoints=2048,
            device=device,
        )
        query_features = feat_stage["data"]["features"]

        # Step 5: Match against catalog
        logger.info("Matching against catalog...")
        pairwise_dir = temp_dir / "pairwise"
        pairwise_dir.mkdir(exist_ok=True)

        match_stage = run_matching_stage(
            query_features=query_features,
            catalog_path=config.catalog_root,
            top_k=top_k,
            extractor=extractor,
            device=device,
            query_image_path=str(cropped_path),
            pairwise_output_dir=pairwise_dir,
            filter_locations=filter_locations,
            filter_body_parts=filter_body_parts,
        )
        matches = match_stage["data"]["matches"]

        if not matches:
            return _empty_pipeline_outputs(
                "No matches found in catalog", cropped_image=cropped_image_pil
            )

        # Top match headline (emoji encodes confidence)
        top_match = matches[0]
        result_text = (
            f"## {_confidence_emoji(top_match['wasserstein'])} "
            f"{top_match['leopard_name'].title()}"
        )

        # Create segmentation visualization
        seg_viz = create_segmentation_viz(
            image_path=temp_image_path, mask=selected_pred["mask"]
        )

        # Generate extracted keypoints visualization (best-effort)
        extracted_kpts_viz = None
        try:
            query_kpts = query_features["keypoints"].cpu().numpy()
            extracted_kpts_viz = draw_keypoints_overlay(
                image_path=cropped_path,
                keypoints=query_kpts,
                max_keypoints=500,
                color="blue",
                ps=10,
            )
        except Exception as e:
            logger.error(f"Error creating extracted keypoints visualization: {e}")

        # Build dataset for top-K matches table plus per-rank visualizations
        dataset_samples = []
        match_visualizations = {}
        clean_comparison_visualizations = {}

        for match in matches:
            rank = match["rank"]
            leopard_name = match["leopard_name"]
            wasserstein = match["wasserstein"]
            catalog_img_path = Path(match["filepath"])

            location = _location_for_match(config.catalog_root, match["catalog_id"])

            # Create visualizations for this match from the pairwise .npz dump
            npz_path = pairwise_dir / f"rank_{rank:02d}_{match['catalog_id']}.npz"
            if npz_path.exists():
                try:
                    pairwise_data = np.load(npz_path)
                    match_visualizations[rank] = draw_matched_keypoints(
                        query_image_path=cropped_path,
                        catalog_image_path=catalog_img_path,
                        query_keypoints=pairwise_data["query_keypoints"],
                        catalog_keypoints=pairwise_data["catalog_keypoints"],
                        match_scores=pairwise_data["match_scores"],
                        max_matches=100,
                    )
                    clean_comparison_visualizations[rank] = draw_side_by_side_comparison(
                        query_image_path=cropped_path,
                        catalog_image_path=catalog_img_path,
                    )
                except Exception as e:
                    logger.error(f"Error creating visualizations for rank {rank}: {e}")

            # Format for table (as list, not dict)
            dataset_samples.append(
                [
                    rank,
                    _confidence_emoji(wasserstein),
                    leopard_name.title(),
                    location.replace("_", " ").title(),
                    f"{wasserstein:.4f}",
                ]
            )

        # Store match visualizations, enriched matches, filters, and temp_dir
        # in global state so the row-selection handler can reuse them.
        LOADED_MODELS["current_match_visualizations"] = match_visualizations
        LOADED_MODELS["current_clean_comparison_visualizations"] = (
            clean_comparison_visualizations
        )
        LOADED_MODELS["current_enriched_matches"] = matches
        LOADED_MODELS["current_filter_body_parts"] = filter_body_parts
        LOADED_MODELS["current_temp_dir"] = temp_dir

        # Automatically load rank 1 details (visualizations + galleries)
        rank1_details = load_match_details_for_rank(rank=1)

        # 5 pipeline outputs + 18 rank-1 detail updates = 23 outputs
        return (
            result_text,  # 1. Top match result text
            seg_viz,  # 2. Segmentation overlay
            cropped_image_pil,  # 3. Cropped leopard
            extracted_kpts_viz,  # 4. Extracted keypoints
            dataset_samples,  # 5. Matches table data
            *rank1_details,  # 6-23. visualizations, header, indicators, galleries
        )

    except Exception as e:
        logger.error(f"Error processing image: {e}", exc_info=True)
        return _empty_pipeline_outputs(f"Error processing image: {str(e)}")
|
| 801 |
+
|
| 802 |
+
|
| 803 |
+
def create_segmentation_viz(image_path, mask):
    """Overlay the segmentation mask on the original image as a red tint."""
    bgr = cv2.imread(str(image_path))
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    height, width = rgb.shape[:2]
    # Align mask to image resolution when needed (nearest keeps it binary).
    if mask.shape[:2] == (height, width):
        aligned_mask = mask
    else:
        aligned_mask = cv2.resize(
            mask.astype(np.uint8),
            (width, height),
            interpolation=cv2.INTER_NEAREST,
        )

    # Paint masked pixels red on a copy, then alpha-blend with the original.
    tinted = rgb.copy()
    tinted[aligned_mask > 0] = [255, 0, 0]  # Red for masked region
    blend_alpha = 0.4
    composited = cv2.addWeighted(rgb, 1 - blend_alpha, tinted, blend_alpha, 0)

    return Image.fromarray(composited)
|
| 828 |
+
|
| 829 |
+
|
| 830 |
+
def load_match_details_for_rank(rank: int) -> tuple:
    """Load all match details (visualizations + galleries) for a specific rank.

    Reusable helper shared by the automatic rank-1 display after the pipeline
    completes and by the interactive table row-selection handler.

    Args:
        rank: The rank to load (1-indexed)

    Returns:
        Tuple of 18 Gradio component updates, in order:
        matched_kpts_viz, clean_comparison_viz, header,
        5 body-part indicators (head, left_flank, right_flank, tail, misc),
        5 empty-state messages, 5 galleries.
    """
    # Pull everything the pipeline stashed in global state.
    state = LOADED_MODELS
    match_visualizations = state.get("current_match_visualizations", {})
    clean_comparison_visualizations = state.get(
        "current_clean_comparison_visualizations", {}
    )
    enriched_matches = state.get("current_enriched_matches", [])
    filter_body_parts = state.get("current_filter_body_parts")
    catalog_root = state.get("catalog_root")

    # Locate the match entry for the requested rank, if any.
    selected_match = next(
        (m for m in enriched_matches if m["rank"] == rank), None
    )

    if selected_match is None or rank not in match_visualizations:
        # Clear all 18 outputs: 2 visualizations, header, 5 indicators,
        # 5 empty messages, 5 galleries.
        return (
            gr.update(value=None),  # matched_kpts_viz
            gr.update(value=None),  # clean_comparison_viz
            gr.update(value=""),  # header
            gr.update(value=""),  # head indicator
            gr.update(value=""),  # left_flank indicator
            gr.update(value=""),  # right_flank indicator
            gr.update(value=""),  # tail indicator
            gr.update(value=""),  # misc indicator
            gr.update(visible=False),  # head empty message
            gr.update(visible=False),  # left_flank empty message
            gr.update(visible=False),  # right_flank empty message
            gr.update(visible=False),  # tail empty message
            gr.update(visible=False),  # misc empty message
            gr.update(value=[]),  # head gallery
            gr.update(value=[]),  # left_flank gallery
            gr.update(value=[]),  # right_flank gallery
            gr.update(value=[]),  # tail gallery
            gr.update(value=[]),  # misc gallery
        )

    match_viz = match_visualizations[rank]
    clean_viz = clean_comparison_visualizations.get(rank)

    # Dynamic header with the matched individual's name.
    leopard_name = selected_match["leopard_name"]
    header_text = f"## Reference Images for {leopard_name.title()}"

    body_part_order = ("head", "left_flank", "right_flank", "tail", "misc")

    # Load reference galleries organized by body part.
    galleries = {}
    if catalog_root:
        try:
            # Extract location from the match filepath:
            # .../database/{location}/{individual}/...
            location = None
            parts = Path(selected_match["filepath"]).parts
            if "database" in parts:
                db_idx = parts.index("database")
                if db_idx + 1 < len(parts):
                    location = parts[db_idx + 1]

            galleries = load_matched_individual_gallery_by_body_part(
                catalog_root=catalog_root,
                leopard_name=leopard_name,
                location=location,
            )
        except Exception as e:
            logger.error(f"Error loading gallery for {leopard_name}: {e}")
            # Initialize empty galleries on error
            galleries = {part: [] for part in body_part_order}

    def get_indicator(body_part: str) -> str:
        """Return star if body part was in filter, empty string otherwise."""
        if filter_body_parts and body_part in filter_body_parts:
            return "* (filtered)"
        return ""

    def is_empty(body_part: str) -> bool:
        """Return True if no images for this body part."""
        return not galleries.get(body_part, [])

    # Build the per-body-part update groups in the fixed component order.
    indicator_updates = tuple(
        gr.update(value=get_indicator(part)) for part in body_part_order
    )
    empty_message_updates = tuple(
        gr.update(visible=is_empty(part)) for part in body_part_order
    )
    gallery_updates = tuple(
        gr.update(value=galleries.get(part, []), visible=not is_empty(part))
        for part in body_part_order
    )

    return (
        gr.update(value=match_viz),  # matched_kpts_viz
        gr.update(value=clean_viz),  # clean_comparison_viz
        gr.update(value=header_text),  # header
        *indicator_updates,  # 5 body-part indicators
        *empty_message_updates,  # 5 empty-state messages
        *gallery_updates,  # 5 galleries
    )
|
| 967 |
+
|
| 968 |
+
|
| 969 |
+
def on_match_selected(evt: gr.SelectData):
    """Handle selection of a match from the dataset table.

    Returns both visualizations, header, indicators, empty messages,
    and galleries organized by body part.
    """
    # Dataframe selection events report [row, col]; a scalar is the row itself.
    index = evt.index
    row = index[0] if isinstance(index, (list, tuple)) else index

    # Table rows are 0-indexed while match ranks start at 1;
    # delegate the heavy lifting to the shared helper.
    return load_match_details_for_rank(row + 1)
|
| 985 |
+
|
| 986 |
+
|
| 987 |
+
def load_matched_individual_gallery_by_body_part(
    catalog_root: Path,
    leopard_name: str,
    location: str | None = None,
) -> dict[str, list[tuple]]:
    """Load all images for a matched individual organized by body part.

    Args:
        catalog_root: Path to catalog root directory
        leopard_name: Name of the matched individual (e.g., "karindas")
        location: Geographic location (e.g., "skycrest_valley")

    Returns:
        Dict mapping each body part ("head", "left_flank", "right_flank",
        "tail", "misc") to a list of (PIL.Image, caption) tuples; lists are
        empty when no images are found or metadata is missing.
    """
    galleries: dict[str, list[tuple]] = {
        part: [] for part in ("head", "left_flank", "right_flank", "tail", "misc")
    }

    # Metadata lives at database/{location}/{individual}/metadata.yaml.
    if location:
        metadata_path = (
            catalog_root / "database" / location / leopard_name / "metadata.yaml"
        )
    else:
        # Location unknown: scan every location directory for this individual.
        metadata_path = None
        database_dir = catalog_root / "database"
        if database_dir.exists():
            metadata_path = next(
                (
                    loc_dir / leopard_name / "metadata.yaml"
                    for loc_dir in database_dir.iterdir()
                    if loc_dir.is_dir()
                    and (loc_dir / leopard_name / "metadata.yaml").exists()
                ),
                None,
            )

    if not metadata_path or not metadata_path.exists():
        logger.warning(f"Metadata not found for {leopard_name}")
        return galleries

    try:
        metadata = load_leopard_metadata(metadata_path)

        for img_entry in metadata["reference_images"]:
            # Unknown/unlisted body parts fall back to the "misc" bucket.
            body_part = img_entry.get("body_part", "misc")
            if body_part not in galleries:
                body_part = "misc"

            img_path = catalog_root / "database" / img_entry["path"]
            try:
                # Caption is just the body part name.
                galleries[body_part].append((Image.open(img_path), body_part))
            except Exception as e:
                logger.error(f"Error loading image {img_path}: {e}")

    except Exception as e:
        logger.error(f"Error loading metadata for {leopard_name}: {e}")

    return galleries
|
| 1065 |
+
|
| 1066 |
+
|
| 1067 |
+
def cleanup_temp_files():
    """Clean up temporary files from previous run."""
    previous_dir = LOADED_MODELS.get("current_temp_dir")
    # Nothing recorded, or already gone — nothing to do.
    if not previous_dir or not previous_dir.exists():
        return
    try:
        shutil.rmtree(previous_dir)
        logger.info(f"Cleaned up temporary directory: {previous_dir}")
    except Exception as e:
        # Best-effort cleanup; a failure here should never break the app.
        logger.warning(f"Error cleaning up temp directory: {e}")
|
| 1076 |
+
|
| 1077 |
+
|
| 1078 |
+
def create_leopard_tab(leopard_metadata, config: AppConfig):
    """Create a tab for displaying a single leopard's images.

    Args:
        leopard_metadata: Metadata dictionary for the leopard individual
        config: Application configuration
    """
    # Support both 'leopard_name' and 'individual_name' keys
    leopard_name = leopard_metadata.get("leopard_name") or leopard_metadata.get(
        "individual_name"
    )
    location = leopard_metadata.get("location", "unknown")

    stats = leopard_metadata["statistics"]
    total_images = stats["total_reference_images"]
    # Older metadata uses 'body_parts'; newer uses 'body_parts_represented'.
    body_parts = stats.get("body_parts_represented", stats.get("body_parts", []))
    body_parts_str = ", ".join(body_parts) if body_parts else "N/A"

    with gr.Tab(f"{leopard_name}"):
        # Header with statistics
        gr.Markdown(
            f"### {leopard_name.title()}\n"
            f"**Location:** {location.replace('_', ' ').title()} | "
            f"**{total_images} images** | "
            f"**Body parts:** {body_parts_str}"
        )

        # Load every reference image, captioned with its body part
        # (location is already shown in the tab header).
        gallery_data = []
        for img_entry in leopard_metadata["reference_images"]:
            img_path = config.catalog_root / "database" / img_entry["path"]
            body_part = img_entry.get("body_part", "unknown")
            try:
                gallery_data.append((Image.open(img_path), body_part))
            except Exception as e:
                logger.error(f"Error loading image {img_path}: {e}")

        # Display gallery
        gr.Gallery(
            value=gallery_data,
            label=f"Reference Images for {leopard_name.title()}",
            columns=6,
            height=700,
            object_fit="scale-down",
            allow_preview=True,
        )
|
| 1129 |
+
|
| 1130 |
+
|
| 1131 |
+
def create_app(config: AppConfig) -> gr.Blocks:
    """Create and configure the Gradio application.

    Builds the two-tab UI ("Identify Snow Leopard" and "Explore Catalog"),
    wires all event handlers, and preloads the first example image.

    Args:
        config: Application configuration

    Returns:
        The fully wired ``gr.Blocks`` application (not yet launched).
    """
    # Extract data archives on first run (for HF Spaces deployment)
    ensure_data_extracted()

    # Initialize models at startup
    initialize_models(config)

    # Load catalog data
    catalog_index, individuals_data = load_catalog_data(config)

    # Build example images list from examples directory.
    # NOTE(review): only .jpg/.JPG/.png are globbed — .jpeg/.PNG files in the
    # examples dir would be silently ignored; confirm this is intentional.
    example_images = (
        list(config.examples_dir.glob("*.jpg"))
        + list(config.examples_dir.glob("*.JPG"))
        + list(config.examples_dir.glob("*.png"))
    )
    # Sort with Ayima images last
    example_images.sort(key=lambda x: (1 if "Ayima" in x.name else 0, x.name))

    # Create interface
    with gr.Blocks(title="Snow Leopard Identification") as app:
        # Hidden state to track which example image was selected (for cache lookup)
        selected_example_state = gr.State(value=None)

        gr.HTML("""
        <div style="text-align: center; margin-bottom: 20px;">
            <h1 style="margin-bottom: 10px;">Snow Leopard Identification</h1>
            <p style="font-size: 16px; color: #666;">
                Computer vision system for identifying individual snow leopards.
            </p>
        </div>
        """)

        # Main tabs
        with gr.Tabs():
            # Tab 1: Identify Snow Leopard
            with gr.Tab("Identify Snow Leopard"):
                gr.Markdown("""
                Upload a snow leopard image or select an example to identify which individual it is.
                The system will detect the leopard, extract distinctive features, and match against the catalog.
                """)

                with gr.Row():
                    # Left column: Input
                    with gr.Column(scale=1):
                        image_input = gr.Image(
                            type="pil",
                            label="Upload Snow Leopard Image",
                            sources=["upload", "clipboard"],
                        )

                        examples_component = gr.Examples(
                            examples=[[str(img)] for img in example_images],
                            inputs=image_input,
                            label="Example Images",
                        )

                        # Track example selection for cache lookup
                        def on_example_select(evt: gr.SelectData):
                            """Update state when an example is selected."""
                            if evt.index is not None:
                                return str(example_images[evt.index])
                            return None

                        # When image changes, check if it matches an example.
                        # NOTE(review): this handler is defined but never wired
                        # to an event below — it appears to be dead code.
                        def check_if_example(img):
                            """Check if uploaded image matches an example path."""
                            # When user uploads a new image, clear the example state
                            # Examples component handles setting state via select event
                            return gr.update()  # No change to state on image change

                        examples_component.dataset.select(
                            fn=on_example_select,
                            outputs=[selected_example_state],
                        )

                        # Clear example state when user uploads a new image
                        image_input.upload(
                            fn=lambda: None,
                            outputs=[selected_example_state],
                        )

                        # Location filter dropdown
                        available_locations = get_available_locations(
                            config.catalog_root
                        )
                        location_filter = gr.Dropdown(
                            choices=available_locations,
                            value=["all"],
                            multiselect=True,
                            label="Filter by Location",
                            info="Select locations to search (default: all locations)",
                        )

                        # Body part filter dropdown
                        available_body_parts = get_available_body_parts(
                            config.catalog_root
                        )
                        body_part_filter = gr.Dropdown(
                            choices=available_body_parts,
                            value=["all"],
                            multiselect=True,
                            label="Filter by Body Part",
                            info="Select body parts to match (default: all body parts)",
                        )

                        # Advanced Configuration Accordion
                        with gr.Accordion("Advanced Configuration", open=False):
                            # Feature extractor dropdown — prefer "sift" when
                            # available, otherwise fall back to the first
                            # extractor found (or "sift" if none cached).
                            available_extractors = get_available_extractors(
                                config.catalog_root
                            )
                            extractor_dropdown = gr.Dropdown(
                                choices=available_extractors,
                                value="sift"
                                if "sift" in available_extractors
                                else (
                                    available_extractors[0]
                                    if available_extractors
                                    else "sift"
                                ),
                                label="Feature Extractor",
                                info=f"Available: {', '.join(available_extractors)}",
                                scale=1,
                            )

                            # Top-K parameter
                            top_k_input = gr.Number(
                                value=config.top_k,
                                label="Top-K Matches",
                                info="Number of top matches to return",
                                minimum=1,
                                maximum=20,
                                step=1,
                                precision=0,
                                scale=1,
                            )

                        submit_btn = gr.Button(
                            value="Identify Snow Leopard",
                            variant="primary",
                            size="lg",
                        )

                    # Right column: Results
                    with gr.Column(scale=4):
                        # Top-1 prediction
                        result_text = gr.Markdown("")

                        # Tabs for different result views
                        with gr.Tabs():
                            with gr.Tab("Model Internals"):
                                gr.Markdown("""
                                View the internal processing steps: segmentation mask, cropped leopard, and extracted keypoints.
                                """)
                                with gr.Row():
                                    seg_viz = gr.Image(
                                        label="Segmentation Overlay",
                                        type="pil",
                                    )
                                    cropped_image = gr.Image(
                                        label="Extracted Snow Leopard",
                                        type="pil",
                                    )
                                    extracted_kpts_viz = gr.Image(
                                        label="Extracted Keypoints",
                                        type="pil",
                                    )

                            with gr.Tab("Top Matches"):
                                gr.Markdown("""
                                Click a row to view detailed feature matching visualization and all reference images for that leopard.

                                **Higher Wasserstein distance = better match** (typical range: 0.04-0.27)

                                **Confidence Levels:** 🔵 Excellent (>=0.12) | 🟢 Good (>=0.07) | 🟡 Fair (>=0.04) | 🔴 Uncertain (<0.04)
                                """)

                                matches_dataset = gr.Dataframe(
                                    headers=[
                                        "Rank",
                                        "Confidence",
                                        "Leopard Name",
                                        "Location",
                                        "Wasserstein",
                                    ],
                                    label="Top Matches",
                                    wrap=True,
                                    col_count=(5, "fixed"),
                                )

                                # Visualization container (always visible, images populated on pipeline completion)
                                # NOTE(review): `viz_tabs` is never referenced again.
                                with gr.Column() as viz_tabs:
                                    # Tabbed visualization views
                                    with gr.Tabs():
                                        with gr.Tab("Matched Keypoints"):
                                            gr.Markdown(
                                                "Feature matching with keypoints and confidence-coded connecting lines. "
                                                "**Green** = high confidence, **Yellow** = medium, **Red** = low."
                                            )
                                            matched_kpts_viz = gr.Image(
                                                type="pil",
                                                show_label=False,
                                            )

                                        with gr.Tab("Clean Comparison"):
                                            gr.Markdown(
                                                "Side-by-side comparison without feature annotations. "
                                                "Useful for assessing overall visual similarity and spotting patterns."
                                            )
                                            clean_comparison_viz = gr.Image(
                                                type="pil",
                                                show_label=False,
                                            )

                                    # Dynamic header showing matched leopard name
                                    selected_match_header = gr.Markdown("", visible=True)

                                    # Create tabs for each body part
                                    with gr.Tabs():
                                        with gr.Tab("Head"):
                                            head_indicator = gr.Markdown("")
                                            head_empty_message = gr.Markdown(
                                                value='<div style="text-align: center; padding: 60px 20px; color: #888;">'
                                                '<p style="font-size: 16px;">No reference images available for this body part</p>'
                                                "</div>",
                                                visible=False,
                                            )
                                            gallery_head = gr.Gallery(
                                                columns=6,
                                                height=400,
                                                object_fit="scale-down",
                                                allow_preview=True,
                                            )

                                        with gr.Tab("Left Flank"):
                                            left_flank_indicator = gr.Markdown("")
                                            left_flank_empty_message = gr.Markdown(
                                                value='<div style="text-align: center; padding: 60px 20px; color: #888;">'
                                                '<p style="font-size: 16px;">No reference images available for this body part</p>'
                                                "</div>",
                                                visible=False,
                                            )
                                            gallery_left_flank = gr.Gallery(
                                                columns=6,
                                                height=400,
                                                object_fit="scale-down",
                                                allow_preview=True,
                                            )

                                        with gr.Tab("Right Flank"):
                                            right_flank_indicator = gr.Markdown("")
                                            right_flank_empty_message = gr.Markdown(
                                                value='<div style="text-align: center; padding: 60px 20px; color: #888;">'
                                                '<p style="font-size: 16px;">No reference images available for this body part</p>'
                                                "</div>",
                                                visible=False,
                                            )
                                            gallery_right_flank = gr.Gallery(
                                                columns=6,
                                                height=400,
                                                object_fit="scale-down",
                                                allow_preview=True,
                                            )

                                        with gr.Tab("Tail"):
                                            tail_indicator = gr.Markdown("")
                                            tail_empty_message = gr.Markdown(
                                                value='<div style="text-align: center; padding: 60px 20px; color: #888;">'
                                                '<p style="font-size: 16px;">No reference images available for this body part</p>'
                                                "</div>",
                                                visible=False,
                                            )
                                            gallery_tail = gr.Gallery(
                                                columns=6,
                                                height=400,
                                                object_fit="scale-down",
                                                allow_preview=True,
                                            )

                                        with gr.Tab("Other"):
                                            misc_indicator = gr.Markdown("")
                                            misc_empty_message = gr.Markdown(
                                                value='<div style="text-align: center; padding: 60px 20px; color: #888;">'
                                                '<p style="font-size: 16px;">No reference images available for this body part</p>'
                                                "</div>",
                                                visible=False,
                                            )
                                            gallery_misc = gr.Gallery(
                                                columns=6,
                                                height=400,
                                                object_fit="scale-down",
                                                allow_preview=True,
                                            )

                # Connect submit button
                submit_btn.click(
                    fn=lambda img, ext, top_k, locs, parts, ex_path: run_identification(
                        image=img,
                        extractor=ext,
                        top_k=int(top_k),
                        selected_locations=locs,
                        selected_body_parts=parts,
                        example_path=ex_path,
                        config=config,
                    ),
                    inputs=[
                        image_input,
                        extractor_dropdown,
                        top_k_input,
                        location_filter,
                        body_part_filter,
                        selected_example_state,
                    ],
                    outputs=[
                        # Pipeline outputs (5 total)
                        result_text,
                        seg_viz,
                        cropped_image,
                        extracted_kpts_viz,
                        matches_dataset,
                        # Rank 1 auto-display outputs (18 total)
                        matched_kpts_viz,
                        clean_comparison_viz,
                        selected_match_header,
                        head_indicator,
                        left_flank_indicator,
                        right_flank_indicator,
                        tail_indicator,
                        misc_indicator,
                        head_empty_message,
                        left_flank_empty_message,
                        right_flank_empty_message,
                        tail_empty_message,
                        misc_empty_message,
                        gallery_head,
                        gallery_left_flank,
                        gallery_right_flank,
                        gallery_tail,
                        gallery_misc,
                    ],
                )

                # Connect dataset selection — clicking a match row re-renders
                # the visualization panels for that leopard.
                matches_dataset.select(
                    fn=on_match_selected,
                    outputs=[
                        matched_kpts_viz,
                        clean_comparison_viz,
                        selected_match_header,
                        head_indicator,
                        left_flank_indicator,
                        right_flank_indicator,
                        tail_indicator,
                        misc_indicator,
                        head_empty_message,
                        left_flank_empty_message,
                        right_flank_empty_message,
                        tail_empty_message,
                        misc_empty_message,
                        gallery_head,
                        gallery_left_flank,
                        gallery_right_flank,
                        gallery_tail,
                        gallery_misc,
                    ],
                )

            # Tab 2: Explore Catalog
            with gr.Tab("Explore Catalog"):
                gr.Markdown(
                    """
                    ## Snow Leopard Catalog Browser

                    Browse the reference catalog of known snow leopard individuals.
                    Each individual has multiple reference images from different body parts and locations.
                    """
                )

                # Display catalog statistics
                stats = catalog_index.get("statistics", {})
                formatted_locations = [loc.replace("_", " ").title() for loc in stats.get("locations", [])]
                gr.Markdown(
                    f"""
                    ### Catalog Statistics
                    - **Total Individuals:** {stats.get("total_individuals", "N/A")}
                    - **Total Images:** {stats.get("total_reference_images", "N/A")}
                    - **Locations:** {", ".join(formatted_locations)}
                    - **Body Parts:** {", ".join(stats.get("body_parts", []))}
                    """
                )

                gr.Markdown("---")
                gr.Markdown("### Individual Leopards by Location")

                # Group individuals by location
                individuals_by_location = {}
                for individual_data in individuals_data:
                    location = individual_data.get("location", "unknown")
                    if location not in individuals_by_location:
                        individuals_by_location[location] = []
                    individuals_by_location[location].append(individual_data)

                # Create tabs for each location
                with gr.Tabs():
                    for location in sorted(individuals_by_location.keys()):
                        with gr.Tab(f"{location.replace('_', ' ').title()}"):
                            # Create subtabs for each individual in this location
                            with gr.Tabs():
                                for leopard_data in individuals_by_location[location]:
                                    create_leopard_tab(
                                        leopard_metadata=leopard_data, config=config
                                    )

        # Cleanup on app close
        app.unload(cleanup_temp_files)

        # Load first example image on startup
        def load_first_example():
            """Load the first example image when the app starts."""
            if example_images:
                try:
                    first_image = Image.open(example_images[0])
                    return first_image
                except Exception as e:
                    logger.error(f"Error loading first example image: {e}")
                    return None
            return None

        app.load(fn=load_first_example, outputs=[image_input])

    return app
|
| 1568 |
+
|
| 1569 |
+
|
| 1570 |
+
if __name__ == "__main__":
    # Ensure SAM model is downloaded
    logger.info("Checking for SAM HQ model...")
    sam_path = ensure_sam_model()

    # Validate required directories exist
    if not CATALOG_ROOT.exists():
        logger.error(f"Catalog not found: {CATALOG_ROOT}")
        logger.error("Please ensure catalog data is present in data/catalog/")
        # `exit()` is injected by the site module and is not guaranteed to
        # exist (e.g. under `python -S`); SystemExit is always available.
        raise SystemExit(1)

    if not EXAMPLES_DIR.exists():
        logger.warning(f"Examples directory not found: {EXAMPLES_DIR}")
        EXAMPLES_DIR.mkdir(parents=True, exist_ok=True)

    # Create config
    config = AppConfig(
        model_path=None,  # Not using YOLO
        catalog_root=CATALOG_ROOT,
        examples_dir=EXAMPLES_DIR,
        top_k=TOP_K_DEFAULT,
        port=7860,
        share=False,
        sam_checkpoint_path=sam_path,
        sam_model_type=SAM_MODEL_TYPE,
        gdino_model_id=GDINO_MODEL_ID,
        text_prompt=TEXT_PROMPT,
    )

    # Build and launch app
    logger.info("Building Gradio interface...")
    app = create_app(config)

    logger.info("Launching app...")
    # Bind all interfaces so the server is reachable from outside the
    # HF Spaces container; port 7860 is the Spaces convention.
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
|
data/cache.tar.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8d9e2f29be8e2e38d250becadb88625bb598a9a7359e1107a90b48edbfadddcc
|
| 3 |
+
size 213437895
|
data/catalog.tar.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:034ffb5eec607ebcce6372a725d947f0607866d9d89272e2d20b32604ffcfbe7
|
| 3 |
+
size 276174890
|
data/examples/07070305 Agim.JPG
ADDED
|
|
Git LFS Details
|
data/examples/08190121 Karindas.JPG
ADDED
|
|
Git LFS Details
|
data/examples/08190742 Ayima.jpg
ADDED
|
Git LFS Details
|
data/examples/09150237 AIKA.JPG
ADDED
|
|
Git LFS Details
|
data/examples/IMG_7189 Ayima.JPG
ADDED
|
|
Git LFS Details
|
pyproject.toml
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Packaging configuration for the Snow Leopard Re-ID Gradio app (HF Spaces).
[project]
name = "snowleopard-reid-gradio"
version = "0.1.0"
description = "Snow Leopard Re-Identification Gradio App for Hugging Face Spaces"
requires-python = ">=3.11"
dependencies = [
    "gradio>=5.49.1",
    "torch>=2.0.0",
    "transformers>=4.30.0",
    "timm>=0.9.0",
    "ultralytics>=8.3.78",
    "segment-anything-hq>=0.3.0",
    # LightGlue has no PyPI release; it is installed straight from GitHub.
    "lightglue @ git+https://github.com/cvg/LightGlue.git",
    "opencv-python>=4.11.0.86",
    "numpy>=2.0.0",
    "pillow>=11.0.0",
    "pydantic>=2.0.0",
    "pyyaml>=6.0.0",
    "matplotlib>=3.8.0",
    "scipy>=1.10.0",
    "huggingface_hub>=0.20.0",
]

# Ship only the library package from src/ in wheels.
[tool.hatch.build.targets.wheel]
packages = ["src/snowleopard_reid"]

# Required so hatch accepts the direct git+https dependency above.
[tool.hatch.metadata]
allow-direct-references = true

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
|
requirements.txt
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core ML
|
| 2 |
+
torch>=2.0.0
|
| 3 |
+
transformers>=4.30.0
|
| 4 |
+
timm>=0.9.0
|
| 5 |
+
ultralytics>=8.3.78
|
| 6 |
+
segment-anything-hq>=0.3.0
|
| 7 |
+
|
| 8 |
+
# Feature matching (git install)
|
| 9 |
+
lightglue @ git+https://github.com/cvg/LightGlue.git
|
| 10 |
+
|
| 11 |
+
# Web UI
|
| 12 |
+
gradio>=5.49.1
|
| 13 |
+
|
| 14 |
+
# Image processing
|
| 15 |
+
opencv-python>=4.11.0.86
|
| 16 |
+
numpy>=2.0.0
|
| 17 |
+
pillow>=11.0.0
|
| 18 |
+
matplotlib>=3.8.0
|
| 19 |
+
|
| 20 |
+
# Utilities
|
| 21 |
+
pydantic>=2.0.0
|
| 22 |
+
pyyaml>=6.0.0
|
| 23 |
+
scipy>=1.10.0
|
| 24 |
+
huggingface_hub>=0.20.0
|
scripts/create_archives.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Create compressed archives of catalog and cache data.
|
| 3 |
+
|
| 4 |
+
This script packages the catalog and cached_results directories into
|
| 5 |
+
tar.gz archives for efficient storage in Git LFS.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
# Create both archives
|
| 9 |
+
python scripts/create_archives.py
|
| 10 |
+
|
| 11 |
+
# Create only catalog archive
|
| 12 |
+
python scripts/create_archives.py --catalog-only
|
| 13 |
+
|
| 14 |
+
# Create only cache archive
|
| 15 |
+
python scripts/create_archives.py --cache-only
|
| 16 |
+
|
| 17 |
+
# Show archive info without creating
|
| 18 |
+
python scripts/create_archives.py --info
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import argparse
|
| 22 |
+
import logging
|
| 23 |
+
import tarfile
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
|
| 26 |
+
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Paths (resolved relative to the repo root; this script lives in scripts/)
PROJECT_ROOT = Path(__file__).parent.parent
CATALOG_DIR = PROJECT_ROOT / "data" / "catalog"
CACHE_DIR = PROJECT_ROOT / "cached_results"
CATALOG_ARCHIVE = PROJECT_ROOT / "data" / "catalog.tar.gz"
CACHE_ARCHIVE = PROJECT_ROOT / "data" / "cache.tar.gz"
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_dir_size(path: Path) -> int:
    """Get total size of directory in bytes."""
    total = 0
    for entry in path.rglob("*"):
        if entry.is_file():
            total += entry.stat().st_size
    return total
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_file_count(path: Path) -> int:
    """Get total number of files in directory."""
    count = 0
    for entry in path.rglob("*"):
        if entry.is_file():
            count += 1
    return count
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def format_size(size_bytes: int) -> str:
    """Format size in human-readable format (e.g. ``'1.5 MB'``)."""
    remaining = size_bytes
    for unit in ("B", "KB", "MB", "GB"):
        if remaining < 1024:
            return f"{remaining:.1f} {unit}"
        remaining = remaining / 1024
    # Anything beyond the GB range is reported in terabytes.
    return f"{remaining:.1f} TB"
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def create_archive(source_dir: Path, archive_path: Path) -> None:
    """Compress *source_dir* into a gzip tarball at *archive_path*.

    Args:
        source_dir: Directory to archive
        archive_path: Output archive path
    """
    # Missing input directory: log and bail out without creating anything.
    if not source_dir.exists():
        logger.error(f"Source directory not found: {source_dir}")
        return

    file_count = get_file_count(source_dir)
    source_size = get_dir_size(source_dir)
    logger.info(f"Archiving {source_dir.name}/")
    logger.info(f"  Source: {format_size(source_size)} ({file_count} files)")

    # Make sure the destination directory exists before writing.
    archive_path.parent.mkdir(parents=True, exist_ok=True)

    # The directory name becomes the single top-level entry in the archive.
    with tarfile.open(archive_path, "w:gz") as tar:
        tar.add(source_dir, arcname=source_dir.name)

    archive_size = archive_path.stat().st_size
    if source_size > 0:
        compression_ratio = (1 - archive_size / source_size) * 100
    else:
        compression_ratio = 0

    logger.info(f"  Archive: {format_size(archive_size)}")
    logger.info(f"  Compression: {compression_ratio:.1f}% reduction")
    logger.info(f"  Created: {archive_path}")
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def show_info() -> None:
    """Show information about directories and existing archives."""
    print("\n=== Directory Info ===")

    for label, directory in (("Catalog", CATALOG_DIR), ("Cache", CACHE_DIR)):
        if not directory.exists():
            print(f"{label}: not found")
        else:
            size = get_dir_size(directory)
            count = get_file_count(directory)
            print(f"{label}: {format_size(size)} ({count} files)")

    print("\n=== Archive Info ===")

    for label, archive in (("Catalog", CATALOG_ARCHIVE), ("Cache", CACHE_ARCHIVE)):
        if not archive.exists():
            print(f"{label}: not created")
        else:
            size = archive.stat().st_size
            print(f"{label}: {format_size(size)}")
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def main():
    """Parse CLI flags and create the requested archive(s)."""
    parser = argparse.ArgumentParser(
        description="Create compressed archives of catalog and cache data"
    )
    parser.add_argument(
        "--catalog-only",
        action="store_true",
        help="Only create catalog archive",
    )
    parser.add_argument(
        "--cache-only",
        action="store_true",
        help="Only create cache archive",
    )
    parser.add_argument(
        "--info",
        action="store_true",
        help="Show info about directories and archives",
    )

    args = parser.parse_args()

    # Info mode: report and stop without touching any archives.
    if args.info:
        show_info()
        return

    # Each "--X-only" flag suppresses the other archive.
    if not args.cache_only:
        create_archive(CATALOG_DIR, CATALOG_ARCHIVE)
    if not args.catalog_only:
        create_archive(CACHE_DIR, CACHE_ARCHIVE)

    print("\n=== Summary ===")
    show_info()
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
scripts/precompute_cache.py
ADDED
|
@@ -0,0 +1,515 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Pre-compute pipeline results for all example images.
|
| 3 |
+
|
| 4 |
+
This script runs the full snow leopard identification pipeline on all example images
|
| 5 |
+
with all available feature extractors, caching the results for instant display
|
| 6 |
+
in the Gradio app.
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
# Process all example images with all extractors
|
| 10 |
+
python scripts/precompute_cache.py
|
| 11 |
+
|
| 12 |
+
# Process specific images
|
| 13 |
+
python scripts/precompute_cache.py --images IMG_001.jpg IMG_002.jpg
|
| 14 |
+
|
| 15 |
+
# Process with specific extractors only
|
| 16 |
+
python scripts/precompute_cache.py --extractors sift superpoint
|
| 17 |
+
|
| 18 |
+
# Clear cache and regenerate all
|
| 19 |
+
python scripts/precompute_cache.py --clear
|
| 20 |
+
|
| 21 |
+
# Show cache summary
|
| 22 |
+
python scripts/precompute_cache.py --summary
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import argparse
|
| 26 |
+
import logging
|
| 27 |
+
import sys
|
| 28 |
+
import tempfile
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
|
| 31 |
+
import cv2
|
| 32 |
+
import numpy as np
|
| 33 |
+
import torch
|
| 34 |
+
from PIL import Image
|
| 35 |
+
|
| 36 |
+
# Add project root to path for imports
|
| 37 |
+
PROJECT_ROOT = Path(__file__).parent.parent
|
| 38 |
+
sys.path.insert(0, str(PROJECT_ROOT / "src"))
|
| 39 |
+
|
| 40 |
+
from snowleopard_reid.cache import (
|
| 41 |
+
CACHE_DIR,
|
| 42 |
+
clear_cache,
|
| 43 |
+
extract_location_body_part_from_filepath,
|
| 44 |
+
get_cache_dir,
|
| 45 |
+
get_cache_summary,
|
| 46 |
+
)
|
| 47 |
+
from snowleopard_reid.pipeline.stages import (
|
| 48 |
+
run_feature_extraction_stage,
|
| 49 |
+
run_matching_stage,
|
| 50 |
+
run_preprocess_stage,
|
| 51 |
+
run_segmentation_stage,
|
| 52 |
+
select_best_mask,
|
| 53 |
+
)
|
| 54 |
+
from snowleopard_reid.pipeline.stages.segmentation import (
|
| 55 |
+
load_gdino_model,
|
| 56 |
+
load_sam_predictor,
|
| 57 |
+
)
|
| 58 |
+
from snowleopard_reid.visualization import draw_keypoints_overlay
|
| 59 |
+
|
| 60 |
+
# Configure logging
|
| 61 |
+
logging.basicConfig(
|
| 62 |
+
level=logging.INFO,
|
| 63 |
+
format="%(asctime)s - %(levelname)s - %(message)s",
|
| 64 |
+
)
|
| 65 |
+
logger = logging.getLogger(__name__)
|
| 66 |
+
|
| 67 |
+
# Configuration
|
| 68 |
+
CATALOG_ROOT = PROJECT_ROOT / "data" / "catalog"
|
| 69 |
+
SAM_CHECKPOINT_DIR = PROJECT_ROOT / "data" / "models"
|
| 70 |
+
SAM_CHECKPOINT_NAME = "sam_hq_vit_l.pth"
|
| 71 |
+
EXAMPLES_DIR = PROJECT_ROOT / "data" / "examples"
|
| 72 |
+
GDINO_MODEL_ID = "IDEA-Research/grounding-dino-base"
|
| 73 |
+
TEXT_PROMPT = "a snow leopard."
|
| 74 |
+
SAM_MODEL_TYPE = "vit_l"
|
| 75 |
+
# Set very high to get ALL matches (will be limited by catalog size)
|
| 76 |
+
TOP_K_ALL = 1000
|
| 77 |
+
# Default top_k for display
|
| 78 |
+
TOP_K_DEFAULT = 5
|
| 79 |
+
|
| 80 |
+
# All available extractors
|
| 81 |
+
ALL_EXTRACTORS = ["sift", "superpoint", "disk", "aliked"]
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def ensure_sam_model() -> Path:
    """Ensure the SAM HQ checkpoint exists locally, downloading it if missing.

    Returns:
        Path to the SAM HQ checkpoint file
    """
    from huggingface_hub import hf_hub_download

    checkpoint = SAM_CHECKPOINT_DIR / SAM_CHECKPOINT_NAME
    if checkpoint.exists():
        return checkpoint

    logger.info("Downloading SAM HQ model (1.6GB)...")
    SAM_CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    hf_hub_download(
        repo_id="lkeab/hq-sam",
        filename=SAM_CHECKPOINT_NAME,
        local_dir=SAM_CHECKPOINT_DIR,
    )
    logger.info("SAM HQ model downloaded successfully")
    return checkpoint
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def create_segmentation_viz(image_path: Path, mask: np.ndarray) -> Image.Image:
    """Overlay the segmentation mask on the original image.

    Args:
        image_path: Path to original image
        mask: Binary segmentation mask

    Returns:
        PIL Image with the masked region tinted red
    """
    bgr = cv2.imread(str(image_path))
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    height, width = rgb.shape[:2]
    if mask.shape[:2] == (height, width):
        mask_full = mask
    else:
        # Nearest-neighbour interpolation keeps the mask binary after resizing.
        mask_full = cv2.resize(
            mask.astype(np.uint8),
            (width, height),
            interpolation=cv2.INTER_NEAREST,
        )

    # Paint the masked region solid red, then alpha-blend with the original.
    tinted = rgb.copy()
    tinted[mask_full > 0] = [255, 0, 0]  # Red for masked region
    alpha = 0.4
    blended = cv2.addWeighted(rgb, 1 - alpha, tinted, alpha, 0)

    return Image.fromarray(blended)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def process_and_cache(
    image_path: Path,
    extractor: str,
    gdino_processor,
    gdino_model,
    sam_predictor,
    device: str,
) -> bool:
    """Run full pipeline and cache ALL results for one image/extractor combination.

    This version caches ALL matches (not just top-k) with location/body_part
    metadata, and stores NPZ pairwise data for on-demand visualization generation.

    Args:
        image_path: Path to example image
        extractor: Feature extractor to use
        gdino_processor: Pre-loaded Grounding DINO processor
        gdino_model: Pre-loaded Grounding DINO model
        sam_predictor: Pre-loaded SAM HQ predictor
        device: Device to run on ('cuda' or 'cpu')

    Returns:
        True if successful, False otherwise
    """
    # FIX: hoisted function-local imports. `shutil` was previously imported
    # inside the per-match loop (re-executed for every match) and `json`
    # mid-function; both now happen once per call.
    import json
    import shutil

    logger.info(f"Processing {image_path.name} with {extractor.upper()}...")

    try:
        # Create temporary directory for intermediate files
        with tempfile.TemporaryDirectory(prefix="snowleopard_cache_") as temp_dir:
            temp_dir = Path(temp_dir)

            # ================================================================
            # Stage 1: Segmentation (GDINO+SAM)
            # ================================================================
            logger.info(" Running GDINO+SAM segmentation...")
            seg_stage = run_segmentation_stage(
                image_path=image_path,
                strategy="gdino_sam",
                confidence_threshold=0.2,
                device=device,
                gdino_processor=gdino_processor,
                gdino_model=gdino_model,
                sam_predictor=sam_predictor,
                text_prompt=TEXT_PROMPT,
                box_threshold=0.30,
                text_threshold=0.20,
            )

            predictions = seg_stage["data"]["predictions"]
            if not predictions:
                logger.warning(f" No snow leopards detected in {image_path.name}")
                return False

            # ================================================================
            # Stage 2: Mask Selection
            # ================================================================
            logger.info(" Selecting best mask...")
            selected_idx, selected_pred = select_best_mask(
                predictions,
                strategy="confidence_area",
            )

            # Create segmentation visualization
            segmentation_image = create_segmentation_viz(
                image_path=image_path,
                mask=selected_pred["mask"],
            )

            # ================================================================
            # Stage 3: Preprocessing
            # ================================================================
            logger.info(" Preprocessing...")
            prep_stage = run_preprocess_stage(
                image_path=image_path,
                mask=selected_pred["mask"],
                padding=5,
            )

            cropped_image = prep_stage["data"]["cropped_image"]

            # Save cropped image for visualization functions
            cropped_path = temp_dir / "cropped.jpg"
            cropped_image.save(cropped_path)

            # ================================================================
            # Stage 4: Feature Extraction
            # ================================================================
            logger.info(f" Extracting features ({extractor.upper()})...")
            feat_stage = run_feature_extraction_stage(
                image=cropped_image,
                extractor=extractor,
                max_keypoints=2048,
                device=device,
            )

            query_features = feat_stage["data"]["features"]

            # Create keypoints visualization
            query_kpts = query_features["keypoints"].cpu().numpy()
            keypoints_image = draw_keypoints_overlay(
                image_path=cropped_path,
                keypoints=query_kpts,
                max_keypoints=500,
                color="blue",
                ps=10,
            )

            # ================================================================
            # Stage 5: Matching - Get ALL matches
            # ================================================================
            logger.info(" Matching against catalog (ALL matches)...")
            temp_pairwise_dir = temp_dir / "pairwise"
            temp_pairwise_dir.mkdir(exist_ok=True)

            match_stage = run_matching_stage(
                query_features=query_features,
                catalog_path=CATALOG_ROOT,
                top_k=TOP_K_ALL,  # Get ALL matches
                extractor=extractor,
                device=device,
                query_image_path=str(cropped_path),
                pairwise_output_dir=temp_pairwise_dir,
            )

            matches = match_stage["data"]["matches"]

            if not matches:
                logger.warning(f" No matches found for {image_path.name}")
                return False

            logger.info(f" Found {len(matches)} matches")

            # ================================================================
            # Enrich matches with location/body_part
            # ================================================================
            logger.info(" Adding location/body_part metadata...")
            for match in matches:
                location, body_part = extract_location_body_part_from_filepath(
                    match["filepath"]
                )
                match["location"] = location
                match["body_part"] = body_part

            # ================================================================
            # Set up cache directory
            # ================================================================
            cache_dir = get_cache_dir(image_path, extractor)
            cache_dir.mkdir(parents=True, exist_ok=True)
            pairwise_dir = cache_dir / "pairwise"
            pairwise_dir.mkdir(exist_ok=True)

            # ================================================================
            # Copy NPZ files with catalog_id naming (not rank-based)
            # ================================================================
            logger.info(" Copying NPZ pairwise data...")
            npz_count = 0
            for match in matches:
                catalog_id = match["catalog_id"]
                rank = match["rank"]

                # Source NPZ (rank-based naming from matching stage)
                src_npz = temp_pairwise_dir / f"rank_{rank:02d}_{catalog_id}.npz"

                # Destination NPZ (catalog_id naming for cache)
                dst_npz = pairwise_dir / f"{catalog_id}.npz"

                if src_npz.exists():
                    shutil.copy2(src_npz, dst_npz)
                    npz_count += 1

            logger.info(f" Copied {npz_count} NPZ files")

            # ================================================================
            # Build Predictions Dict (v2.0 format with all_matches)
            # ================================================================
            predictions_dict = {
                "format_version": "2.0",
                "query_image": str(image_path),
                "extractor": extractor,
                "pipeline": {
                    "segmentation": {
                        "strategy": "gdino_sam",
                        "num_predictions": len(predictions),
                        "selected_idx": selected_idx,
                        "confidence": float(selected_pred["confidence"]),
                    },
                    "preprocessing": {
                        "padding": prep_stage["config"]["padding"],
                    },
                    "features": {
                        "num_keypoints": int(feat_stage["metrics"]["num_keypoints"]),
                        "extractor": extractor,
                        "max_keypoints": 2048,
                    },
                    "matching": {
                        "num_catalog_images": match_stage["metrics"]["num_catalog_images"],
                        "num_successful_matches": match_stage["metrics"]["num_successful_matches"],
                    },
                },
                "all_matches": matches,  # ALL matches with location/body_part
                "top_k": TOP_K_DEFAULT,
            }

            # ================================================================
            # Save Cache (predictions.json + visualization images)
            # ================================================================
            logger.info(" Saving to cache...")

            # Save predictions JSON
            predictions_file = cache_dir / "predictions.json"
            with open(predictions_file, "w") as f:
                json.dump(predictions_dict, f, indent=2)

            # Save visualization images
            segmentation_image.save(cache_dir / "segmentation.png")
            cropped_image.save(cache_dir / "cropped.png")
            keypoints_image.save(cache_dir / "keypoints.png")

            # Log cache size
            cache_size = sum(
                f.stat().st_size for f in cache_dir.rglob("*") if f.is_file()
            )
            logger.info(
                f" Cached: {cache_dir.name} ({cache_size / 1024 / 1024:.2f} MB)"
            )
            logger.info(f" {len(matches)} matches, {npz_count} NPZ files")

            return True

    except Exception as e:
        logger.error(f" Failed: {e}", exc_info=True)
        return False
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def main():
    """CLI entry point: parse args, load models once, precompute all caches."""
    parser = argparse.ArgumentParser(
        description="Pre-compute pipeline results for example images"
    )
    parser.add_argument(
        "--images",
        nargs="+",
        help="Specific image filenames to process (default: all in examples/)",
    )
    parser.add_argument(
        "--extractors",
        nargs="+",
        choices=ALL_EXTRACTORS,
        default=ALL_EXTRACTORS,
        help="Feature extractors to use (default: all)",
    )
    parser.add_argument(
        "--clear",
        action="store_true",
        help="Clear all cached results before processing",
    )
    parser.add_argument(
        "--summary",
        action="store_true",
        help="Show cache summary and exit",
    )
    parser.add_argument(
        "--device",
        choices=["cpu", "cuda"],
        default=None,
        help="Device to run on (default: auto-detect)",
    )

    args = parser.parse_args()

    # Show summary and exit
    if args.summary:
        summary = get_cache_summary()
        print("\n=== Cache Summary ===")
        print(f"Total cached: {summary['total_cached']} items")
        print(f"Total size: {summary['total_size_mb']:.2f} MB")
        print("\nCached items:")
        for item in summary["cached_items"]:
            print(f" - {item['image_stem']} ({item['extractor']}): {item['size_mb']:.2f} MB")
        return

    # Clear cache if requested
    if args.clear:
        logger.info("Clearing cache...")
        clear_cache()
        logger.info("Cache cleared")

    # Determine device (explicit flag wins over auto-detection)
    if args.device:
        device = args.device
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    logger.info(f"Using device: {device}")
    if device == "cuda":
        logger.info(f"GPU: {torch.cuda.get_device_name(0)}")

    # Find example images
    if args.images:
        image_paths = [EXAMPLES_DIR / img for img in args.images]
        # Filter existing files
        image_paths = [p for p in image_paths if p.exists()]
        if not image_paths:
            logger.error("No valid image paths found")
            sys.exit(1)
    else:
        # FIX: de-duplicate glob results. On case-insensitive filesystems
        # (macOS, Windows) "*.jpg" and "*.JPG" match the same files, so the
        # previous list concatenation processed every image twice. Sorting
        # also makes the processing order deterministic.
        image_paths = sorted(
            {
                path
                for pattern in ("*.jpg", "*.JPG", "*.png")
                for path in EXAMPLES_DIR.glob(pattern)
            }
        )

    if not image_paths:
        logger.error(f"No example images found in {EXAMPLES_DIR}")
        sys.exit(1)

    logger.info(f"Found {len(image_paths)} example images")
    logger.info(f"Extractors: {', '.join(args.extractors)}")

    # Ensure SAM model is downloaded
    logger.info("Checking for SAM HQ model...")
    sam_path = ensure_sam_model()

    # Load GDINO model once (shared across all image/extractor combinations)
    logger.info(f"Loading Grounding DINO model: {GDINO_MODEL_ID}...")
    gdino_processor, gdino_model = load_gdino_model(
        model_id=GDINO_MODEL_ID,
        device=device,
    )
    logger.info("Grounding DINO model loaded")

    # Load SAM HQ model once
    logger.info(f"Loading SAM HQ model from {sam_path}...")
    sam_predictor = load_sam_predictor(
        checkpoint_path=sam_path,
        model_type=SAM_MODEL_TYPE,
        device=device,
    )
    logger.info("SAM HQ model loaded")

    # Process all image/extractor combinations
    total = len(image_paths) * len(args.extractors)
    success = 0
    failed = 0

    for i, image_path in enumerate(image_paths):
        for j, extractor in enumerate(args.extractors):
            current = i * len(args.extractors) + j + 1
            logger.info(f"\n[{current}/{total}] Processing...")

            if process_and_cache(
                image_path=image_path,
                extractor=extractor,
                gdino_processor=gdino_processor,
                gdino_model=gdino_model,
                sam_predictor=sam_predictor,
                device=device,
            ):
                success += 1
            else:
                failed += 1

    # Final summary
    logger.info("\n" + "=" * 50)
    logger.info("PRECOMPUTATION COMPLETE")
    logger.info("=" * 50)
    logger.info(f"Success: {success}/{total}")
    logger.info(f"Failed: {failed}/{total}")

    # Show cache summary
    summary = get_cache_summary()
    logger.info(f"Total cache size: {summary['total_size_mb']:.2f} MB")
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
# Script entry point: run the precomputation CLI only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
src/snowleopard_reid/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Snow Leopard Re-Identification Package."""
|
| 2 |
+
|
| 3 |
+
from snowleopard_reid.cache import (
|
| 4 |
+
clear_cache,
|
| 5 |
+
get_cache_summary,
|
| 6 |
+
is_cached,
|
| 7 |
+
load_cached_match_visualizations,
|
| 8 |
+
load_cached_results,
|
| 9 |
+
save_cache_results,
|
| 10 |
+
)
|
| 11 |
+
from snowleopard_reid.images import resize_image_if_needed
|
| 12 |
+
from snowleopard_reid.utils import get_device
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
"clear_cache",
|
| 16 |
+
"get_cache_summary",
|
| 17 |
+
"get_device",
|
| 18 |
+
"is_cached",
|
| 19 |
+
"load_cached_match_visualizations",
|
| 20 |
+
"load_cached_results",
|
| 21 |
+
"resize_image_if_needed",
|
| 22 |
+
"save_cache_results",
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def main() -> None:
    """Print a simple package greeting (console-script smoke test)."""
    greeting = "Hello from snowleopard-reid!"
    print(greeting)
|
src/snowleopard_reid/cache.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Cache utilities for precomputed pipeline results.
|
| 2 |
+
|
| 3 |
+
This module provides functions for loading and saving cached pipeline results,
|
| 4 |
+
enabling instant display of results for example images without running the
|
| 5 |
+
expensive pipeline (GDINO+SAM segmentation, feature extraction, matching) on CPU.
|
| 6 |
+
|
| 7 |
+
Cache Structure (v2.0 - supports filtering):
|
| 8 |
+
cached_results/
|
| 9 |
+
├── {image_stem}_{extractor}/
|
| 10 |
+
│ ├── predictions.json # ALL matches with location/body_part
|
| 11 |
+
│ ├── segmentation.png # Segmentation visualization
|
| 12 |
+
│ ├── cropped.png # Cropped snow leopard image
|
| 13 |
+
│ ├── keypoints.png # Extracted keypoints visualization
|
| 14 |
+
│ └── pairwise/
|
| 15 |
+
│ ├── {catalog_id}.npz # NPZ data for ALL matches
|
| 16 |
+
│ └── ... # (visualizations generated on-demand)
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import copy
|
| 20 |
+
import json
|
| 21 |
+
import logging
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
from PIL import Image
|
| 26 |
+
|
| 27 |
+
from snowleopard_reid.visualization import (
|
| 28 |
+
draw_matched_keypoints,
|
| 29 |
+
draw_side_by_side_comparison,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
# Cache directory relative to project root
|
| 35 |
+
CACHE_DIR = Path("cached_results")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_cache_key(image_path: Path | str, extractor: str) -> str:
|
| 39 |
+
"""Generate cache key from image stem and extractor.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
image_path: Path to the query image
|
| 43 |
+
extractor: Feature extractor name (e.g., 'sift', 'superpoint')
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
Cache key string in format "{image_stem}_{extractor}"
|
| 47 |
+
"""
|
| 48 |
+
image_path = Path(image_path)
|
| 49 |
+
return f"{image_path.stem}_{extractor}"
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_cache_dir(image_path: Path | str, extractor: str) -> Path:
    """Resolve the cache directory for an image/extractor combination.

    Args:
        image_path: Path to the query image
        extractor: Feature extractor name

    Returns:
        Path to the cache directory (under ``CACHE_DIR``)
    """
    key = get_cache_key(image_path, extractor)
    return CACHE_DIR / key
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def is_cached(image_path: Path | str, extractor: str) -> bool:
    """Check if results are cached for this image/extractor combination.

    Args:
        image_path: Path to the query image
        extractor: Feature extractor name

    Returns:
        True only if the predictions JSON and every required visualization
        image exist in the cache directory
    """
    cache_dir = get_cache_dir(image_path, extractor)

    # All of these must be present for the cache entry to be usable.
    required_files = (
        "predictions.json",
        "segmentation.png",
        "cropped.png",
        "keypoints.png",
    )
    return all((cache_dir / name).exists() for name in required_files)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def load_cached_results(image_path: Path | str, extractor: str) -> dict:
    """Load all cached results for an image/extractor combination.

    Args:
        image_path: Path to the query image
        extractor: Feature extractor name

    Returns:
        Dictionary containing:
        - predictions: Full pipeline predictions dict
        - segmentation_image: PIL Image of segmentation overlay
        - cropped_image: PIL Image of cropped snow leopard
        - keypoints_image: PIL Image of extracted keypoints
        - pairwise_dir: Path to directory with match visualizations

    Raises:
        FileNotFoundError: If cache files don't exist
    """
    cache_dir = get_cache_dir(image_path, extractor)
    if not cache_dir.exists():
        raise FileNotFoundError(f"Cache directory not found: {cache_dir}")

    predictions_file = cache_dir / "predictions.json"
    if not predictions_file.exists():
        raise FileNotFoundError(f"Predictions file not found: {predictions_file}")

    # Parse the predictions JSON.
    with open(predictions_file) as f:
        predictions = json.load(f)

    # Open the three precomputed visualization images lazily via PIL.
    return {
        "predictions": predictions,
        "segmentation_image": Image.open(cache_dir / "segmentation.png"),
        "cropped_image": Image.open(cache_dir / "cropped.png"),
        "keypoints_image": Image.open(cache_dir / "keypoints.png"),
        "pairwise_dir": cache_dir / "pairwise",
    }
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def load_cached_match_visualizations(
    pairwise_dir: Path,
    matches: list[dict],
) -> tuple[dict, dict]:
    """Load cached match and clean comparison visualizations.

    Args:
        pairwise_dir: Path to pairwise visualizations directory
        matches: List of match dictionaries with rank and catalog_id

    Returns:
        Tuple of (match_visualizations, clean_comparison_visualizations);
        both map rank -> PIL Image. Missing files are silently skipped.
    """
    match_by_rank: dict = {}
    clean_by_rank: dict = {}

    for entry in matches:
        rank = entry["rank"]
        catalog_id = entry["catalog_id"]

        # Each match may have two precomputed PNGs; load whichever exist.
        match_path = pairwise_dir / f"rank_{rank:02d}_{catalog_id}_match.png"
        clean_path = pairwise_dir / f"rank_{rank:02d}_{catalog_id}_clean.png"

        if match_path.exists():
            match_by_rank[rank] = Image.open(match_path)
        if clean_path.exists():
            clean_by_rank[rank] = Image.open(clean_path)

    return match_by_rank, clean_by_rank
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def save_cache_results(
    image_path: Path | str,
    extractor: str,
    predictions: dict,
    segmentation_image: Image.Image,
    cropped_image: Image.Image,
    keypoints_image: Image.Image,
    match_visualizations: dict[int, Image.Image],
    clean_comparison_visualizations: dict[int, Image.Image],
    matches: list[dict],
) -> Path:
    """Save pipeline results to cache.

    Args:
        image_path: Path to the original query image
        extractor: Feature extractor name
        predictions: Full pipeline predictions dictionary
        segmentation_image: PIL Image of segmentation overlay
        cropped_image: PIL Image of cropped snow leopard
        keypoints_image: PIL Image of extracted keypoints
        match_visualizations: Dict mapping rank -> match visualization PIL Image
        clean_comparison_visualizations: Dict mapping rank -> clean comparison PIL Image
        matches: List of match dictionaries with rank and catalog_id

    Returns:
        Path to the cache directory
    """
    cache_dir = get_cache_dir(image_path, extractor)
    cache_dir.mkdir(parents=True, exist_ok=True)

    # Persist the predictions JSON first.
    predictions_file = cache_dir / "predictions.json"
    with open(predictions_file, "w") as f:
        json.dump(predictions, f, indent=2)
    logger.info(f"Saved predictions: {predictions_file}")

    # Then the three top-level visualization images.
    for filename, image in (
        ("segmentation.png", segmentation_image),
        ("cropped.png", cropped_image),
        ("keypoints.png", keypoints_image),
    ):
        image.save(cache_dir / filename)
    logger.info(f"Saved visualization images to {cache_dir}")

    # Finally the per-match pairwise visualizations (rank-based naming).
    pairwise_dir = cache_dir / "pairwise"
    pairwise_dir.mkdir(exist_ok=True)

    for entry in matches:
        rank = entry["rank"]
        catalog_id = entry["catalog_id"]

        if rank in match_visualizations:
            match_visualizations[rank].save(
                pairwise_dir / f"rank_{rank:02d}_{catalog_id}_match.png"
            )
        if rank in clean_comparison_visualizations:
            clean_comparison_visualizations[rank].save(
                pairwise_dir / f"rank_{rank:02d}_{catalog_id}_clean.png"
            )

    logger.info(f"Saved {len(match_visualizations)} pairwise visualizations")

    return cache_dir
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def clear_cache(
    image_path: Path | str | None = None, extractor: str | None = None
) -> None:
    """Clear cache directory.

    Args:
        image_path: If provided together with ``extractor``, only clear the
            cache for that image/extractor combination.
        extractor: Feature extractor name; must accompany ``image_path``.

    Note:
        If either argument is omitted, the ENTIRE cache directory is removed.
    """
    # FIX: annotations previously read `Path | str = None` / `str = None`,
    # which is the implicit-Optional form disallowed by PEP 484; the defaults
    # are unchanged, only the types now state `| None` explicitly.
    import shutil

    if image_path and extractor:
        # Clear specific cache
        cache_dir = get_cache_dir(image_path, extractor)
        if cache_dir.exists():
            shutil.rmtree(cache_dir)
            logger.info(f"Cleared cache: {cache_dir}")
    elif CACHE_DIR.exists():
        # Clear all caches
        shutil.rmtree(CACHE_DIR)
        logger.info(f"Cleared all caches: {CACHE_DIR}")
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def get_cache_summary() -> dict:
    """Summarize the contents of the cache directory.

    Returns:
        Dictionary with ``total_cached`` (number of cache entries),
        ``total_size_mb`` (combined on-disk size in megabytes) and
        ``cached_items`` (one dict per entry: image stem, extractor,
        size in MB, and path).
    """
    if not CACHE_DIR.exists():
        return {"total_cached": 0, "total_size_mb": 0, "cached_items": []}

    items = []
    size_total = 0

    for entry in CACHE_DIR.iterdir():
        if not entry.is_dir():
            continue

        # On-disk size of every file below this cache entry.
        entry_size = sum(p.stat().st_size for p in entry.rglob("*") if p.is_file())
        size_total += entry_size

        # Cache keys follow "<image_stem>_<extractor>"; anything else gets
        # the whole name as stem and "unknown" as extractor.
        name_parts = entry.name.rsplit("_", 1)
        if len(name_parts) == 2:
            stem, extractor_name = name_parts
        else:
            stem, extractor_name = entry.name, "unknown"

        items.append({
            "image_stem": stem,
            "extractor": extractor_name,
            "size_mb": entry_size / (1024 * 1024),
            "path": str(entry),
        })

    return {
        "total_cached": len(items),
        "total_size_mb": size_total / (1024 * 1024),
        "cached_items": items,
    }
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def filter_cached_matches(
|
| 300 |
+
all_matches: list[dict],
|
| 301 |
+
filter_locations: list[str] | None = None,
|
| 302 |
+
filter_body_parts: list[str] | None = None,
|
| 303 |
+
top_k: int = 5,
|
| 304 |
+
) -> list[dict]:
|
| 305 |
+
"""Filter cached matches by location/body_part and return top-k.
|
| 306 |
+
|
| 307 |
+
Args:
|
| 308 |
+
all_matches: List of all cached match dictionaries
|
| 309 |
+
filter_locations: List of locations to filter by (e.g., ["skycrest_valley"])
|
| 310 |
+
filter_body_parts: List of body parts to filter by (e.g., ["head", "right_flank"])
|
| 311 |
+
top_k: Number of top matches to return after filtering
|
| 312 |
+
|
| 313 |
+
Returns:
|
| 314 |
+
List of filtered and re-ranked match dictionaries
|
| 315 |
+
"""
|
| 316 |
+
# Make a deep copy to avoid modifying the original
|
| 317 |
+
filtered = [copy.deepcopy(m) for m in all_matches]
|
| 318 |
+
|
| 319 |
+
if filter_locations:
|
| 320 |
+
filtered = [m for m in filtered if m.get("location") in filter_locations]
|
| 321 |
+
|
| 322 |
+
if filter_body_parts:
|
| 323 |
+
filtered = [m for m in filtered if m.get("body_part") in filter_body_parts]
|
| 324 |
+
|
| 325 |
+
# Re-sort by wasserstein (descending - higher is better)
|
| 326 |
+
filtered = sorted(filtered, key=lambda x: x.get("wasserstein", 0), reverse=True)
|
| 327 |
+
|
| 328 |
+
# Re-assign ranks for the filtered top-k
|
| 329 |
+
for i, match in enumerate(filtered[:top_k]):
|
| 330 |
+
match["rank"] = i + 1
|
| 331 |
+
|
| 332 |
+
return filtered[:top_k]
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def generate_visualizations_from_npz(
    pairwise_dir: Path,
    matches: list[dict],
    cropped_image_path: Path | str,
) -> tuple[dict, dict]:
    """Build match visualizations on demand from cached NPZ pairwise data.

    Args:
        pairwise_dir: Directory holding one ``<catalog_id>.npz`` file per
            catalog image with the cached keypoint/match arrays.
        matches: Filtered match dictionaries; each must provide ``rank``,
            ``catalog_id`` and ``filepath``.
        cropped_image_path: Path to the cropped query image.

    Returns:
        Tuple ``(match_visualizations, clean_comparison_visualizations)``,
        each mapping rank -> PIL image. Ranks whose NPZ data is missing or
        unreadable are simply absent (a warning is logged instead).
    """
    keypoint_views: dict = {}
    clean_views: dict = {}
    query_path = Path(cropped_image_path)

    for entry in matches:
        entry_rank = entry["rank"]
        entry_id = entry["catalog_id"]
        catalog_path = Path(entry["filepath"])

        # Pairwise data is keyed on catalog_id.
        npz_file = pairwise_dir / f"{entry_id}.npz"
        if not npz_file.exists():
            logger.warning(f"NPZ file not found for {entry_id}: {npz_file}")
            continue

        try:
            # allow_pickle: the precompute step may store object arrays.
            cached = np.load(npz_file, allow_pickle=True)

            # Matched-keypoints overlay for this query/catalog pair.
            keypoint_views[entry_rank] = draw_matched_keypoints(
                query_image_path=query_path,
                catalog_image_path=catalog_path,
                query_keypoints=cached["query_keypoints"],
                catalog_keypoints=cached["catalog_keypoints"],
                match_scores=cached["match_scores"],
                max_matches=100,
            )

            # Clean side-by-side comparison without overlays.
            clean_views[entry_rank] = draw_side_by_side_comparison(
                query_image_path=query_path,
                catalog_image_path=catalog_path,
            )
        except Exception as e:
            logger.warning(
                f"Failed to generate visualization for {entry_id}: {e}"
            )

    return keypoint_views, clean_views
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def extract_location_body_part_from_filepath(filepath: str) -> tuple[str, str]:
    """Derive (location, body_part) from a catalog image path.

    Expects the layout ``.../database/{location}/{individual}/images/{body_part}/<file>``:
    location is the path segment right after ``database`` and body_part the
    segment right after ``images``.

    Args:
        filepath: Path to a catalog image.

    Returns:
        ``(location, body_part)``; ``("unknown", "unknown")`` when either
        marker segment is absent from the path.
    """
    segments = Path(filepath).parts

    try:
        loc_pos = segments.index("database") + 1
        location = segments[loc_pos] if loc_pos < len(segments) else "unknown"

        part_pos = segments.index("images") + 1
        body_part = segments[part_pos] if part_pos < len(segments) else "unknown"
    except (ValueError, IndexError):
        # Either "database" or "images" is missing entirely.
        return "unknown", "unknown"

    return location, body_part
|
src/snowleopard_reid/catalog/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Catalog module for snow leopard re-identification.
|
| 2 |
+
|
| 3 |
+
This module provides utilities for loading and managing the snow leopard catalog,
|
| 4 |
+
including individual metadata and feature data.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .loader import (
|
| 8 |
+
get_all_catalog_features,
|
| 9 |
+
get_available_body_parts,
|
| 10 |
+
get_available_locations,
|
| 11 |
+
get_catalog_metadata_for_id,
|
| 12 |
+
get_filtered_catalog_features,
|
| 13 |
+
load_catalog_index,
|
| 14 |
+
load_leopard_metadata,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
"load_catalog_index",
|
| 19 |
+
"load_leopard_metadata",
|
| 20 |
+
"get_all_catalog_features",
|
| 21 |
+
"get_filtered_catalog_features",
|
| 22 |
+
"get_available_locations",
|
| 23 |
+
"get_available_body_parts",
|
| 24 |
+
"get_catalog_metadata_for_id",
|
| 25 |
+
]
|
src/snowleopard_reid/catalog/loader.py
ADDED
|
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utilities for loading and managing the snow leopard catalog.
|
| 2 |
+
|
| 3 |
+
This module provides functions for loading catalog metadata, individual leopard
|
| 4 |
+
information, and catalog features for matching operations.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import yaml
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def load_catalog_index(catalog_root: Path) -> dict:
    """Load the catalog index YAML file.

    Args:
        catalog_root: Path to catalog root directory (e.g., data/08_catalog/v1.0/)

    Returns:
        Dictionary with catalog index data including:
        - catalog_version: str
        - feature_extractors: dict
        - individuals: list
        - statistics: dict

    Raises:
        FileNotFoundError: If catalog index file doesn't exist
        yaml.YAMLError: If YAML parsing fails
    """
    index_path = catalog_root / "catalog_index.yaml"

    if not index_path.exists():
        raise FileNotFoundError(f"Catalog index not found: {index_path}")

    try:
        with open(index_path) as f:
            return yaml.safe_load(f)
    except yaml.YAMLError as e:
        # Chain the parser error explicitly so the traceback keeps the
        # original line/column diagnostics.
        raise yaml.YAMLError(f"Failed to parse catalog index: {e}") from e
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def load_leopard_metadata(metadata_path: Path) -> dict:
    """Load metadata YAML file for a specific leopard.

    Args:
        metadata_path: Path to leopard metadata.yaml file

    Returns:
        Dictionary with leopard metadata including:
        - individual_id: str
        - leopard_name: str
        - reference_images: list
        - statistics: dict

    Raises:
        FileNotFoundError: If metadata file doesn't exist
        yaml.YAMLError: If YAML parsing fails
    """
    if not metadata_path.exists():
        raise FileNotFoundError(f"Leopard metadata not found: {metadata_path}")

    try:
        with open(metadata_path) as f:
            return yaml.safe_load(f)
    except yaml.YAMLError as e:
        # Chain the parser error explicitly so the traceback keeps the
        # original line/column diagnostics.
        raise yaml.YAMLError(f"Failed to parse leopard metadata: {e}") from e
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def get_all_catalog_features(
    catalog_root: Path,
    extractor: str = "sift",
) -> dict[str, dict[str, torch.Tensor]]:
    """Load every catalog feature file for one extractor.

    Args:
        catalog_root: Path to catalog root directory (e.g., data/08_catalog/v1.0/)
        extractor: Feature extractor name (default: 'sift')

    Returns:
        Mapping of catalog_id -> feature dict (keypoints / descriptors /
        scores tensors as written by the feature-extraction step), e.g.
        ``{"leopard1_2022_001": {"keypoints": ..., "descriptors": ...}, ...}``.

    Raises:
        FileNotFoundError: If the catalog root does not exist.
        ValueError: If the extractor is unknown or no features are found.
    """
    if not catalog_root.exists():
        raise FileNotFoundError(f"Catalog root not found: {catalog_root}")

    catalog_index = load_catalog_index(catalog_root)

    known_extractors = catalog_index.get("feature_extractors", {})
    if extractor not in known_extractors:
        raise ValueError(
            f"Extractor '{extractor}' not available in catalog. "
            f"Available: {list(known_extractors.keys())}"
        )

    database_dir = catalog_root / "database"
    features: dict[str, dict[str, torch.Tensor]] = {}

    for entry in catalog_index["individuals"]:
        # Older index files use 'individual_name' instead of 'leopard_name'.
        name = entry.get("leopard_name") or entry.get("individual_name")
        site = entry.get("location", "")

        # Layout: database/{location}/{individual_name}/ (location optional).
        leopard_dir = database_dir / site / name if site else database_dir / name

        metadata = load_leopard_metadata(leopard_dir / "metadata.yaml")

        for ref_image in metadata["reference_images"]:
            # Only consider images that have features for this extractor.
            if extractor not in ref_image.get("features", {}):
                continue

            # Feature paths in metadata are relative to the database dir.
            feature_path = database_dir / ref_image["features"][extractor]
            if not feature_path.exists():
                # Feature file was never written; skip this image.
                continue

            # Catalog ID is "<name>_<image_id>", e.g. "naguima_2022_001".
            catalog_id = f"{name.lower().replace(' ', '_')}_{ref_image['image_id']}"

            try:
                # weights_only=False: feature files store full tensor dicts.
                features[catalog_id] = torch.load(
                    feature_path, map_location="cpu", weights_only=False
                )
            except Exception:
                # Unreadable feature file — skip it.
                continue

    if not features:
        raise ValueError(f"No features found for extractor '{extractor}' in catalog")

    return features
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def get_filtered_catalog_features(
    catalog_root: Path,
    extractor: str = "sift",
    locations: list[str] | None = None,
    body_parts: list[str] | None = None,
) -> dict[str, dict[str, torch.Tensor]]:
    """Load filtered catalog features for a specific extractor.

    Same loading logic as ``get_all_catalog_features`` but with optional
    filtering on the individual's location and each image's body part.

    Args:
        catalog_root: Path to catalog root directory (e.g., data/08_catalog/v1.0/)
        extractor: Feature extractor name (default: 'sift')
        locations: List of locations to filter by (e.g., ["skycrest_valley", "silvershadow_highlands"]).
            If None, includes all locations.
        body_parts: List of body parts to filter by (e.g., ["head", "right_flank"]).
            If None, includes all body parts.

    Returns:
        Dictionary mapping catalog_id to feature dict:
        {
            "leopard1_2022_001": {
                "keypoints": torch.Tensor,
                "descriptors": torch.Tensor,
                "scores": torch.Tensor,
                ...
            },
            ...
        }

    Raises:
        FileNotFoundError: If catalog doesn't exist
        ValueError: If no features found for extractor or filters
    """
    if not catalog_root.exists():
        raise FileNotFoundError(f"Catalog root not found: {catalog_root}")

    # Load catalog index to get all individuals
    index = load_catalog_index(catalog_root)

    # Reject unknown extractors early with the list of valid options
    available_extractors = index.get("feature_extractors", {})
    if extractor not in available_extractors:
        raise ValueError(
            f"Extractor '{extractor}' not available in catalog. "
            f"Available: {list(available_extractors.keys())}"
        )

    catalog_features: dict[str, dict[str, torch.Tensor]] = {}
    database_dir = catalog_root / "database"

    # Load features for each individual
    for individual in index["individuals"]:
        # Support both 'leopard_name' and 'individual_name' keys
        # (older index files use the latter)
        leopard_name = individual.get("leopard_name") or individual.get(
            "individual_name"
        )
        location = individual.get("location", "")

        # Location filter applies per individual (empty list means no filter
        # is NOT the case here: an explicit empty list filters everything out)
        if locations is not None and location not in locations:
            continue

        # Construct path: database/{location}/{individual_name}/
        # (location segment is omitted when the index has none)
        if location:
            leopard_dir = database_dir / location / leopard_name
        else:
            leopard_dir = database_dir / leopard_name

        # Load leopard metadata to get all reference images
        # (raises FileNotFoundError if metadata.yaml is missing)
        metadata_path = leopard_dir / "metadata.yaml"
        metadata = load_leopard_metadata(metadata_path)

        # Load features for each reference image
        for ref_image in metadata["reference_images"]:
            # Body-part filter applies per reference image
            if body_parts is not None:
                ref_body_part = ref_image.get("body_part", "")
                if ref_body_part not in body_parts:
                    continue

            # Skip images that have no features for this extractor
            if extractor not in ref_image.get("features", {}):
                continue

            # Get feature path (relative to database directory in metadata)
            feature_rel_path = ref_image["features"][extractor]
            feature_path = database_dir / feature_rel_path

            if not feature_path.exists():
                # Feature file was never written; skipped silently
                # NOTE(review): comment previously claimed a warning is
                # logged, but none is — consider adding one.
                continue

            # Create catalog ID: leopard_name_year_imagenum
            # e.g., "naguima_2022_001"
            image_id = ref_image["image_id"]
            catalog_id = f"{leopard_name.lower().replace(' ', '_')}_{image_id}"

            # Load features onto CPU; weights_only=False because feature
            # files store full dicts of tensors, not bare state dicts
            try:
                feats = torch.load(feature_path, map_location="cpu", weights_only=False)
                catalog_features[catalog_id] = feats
            except Exception:
                # Skip files that can't be loaded
                continue

    # Build an error message that names the active filters so the caller
    # can tell an empty catalog apart from over-restrictive filters
    if not catalog_features:
        filter_info = []
        if locations:
            filter_info.append(f"locations={locations}")
        if body_parts:
            filter_info.append(f"body_parts={body_parts}")
        filter_str = ", ".join(filter_info) if filter_info else "no filters"
        raise ValueError(
            f"No features found for extractor '{extractor}' with {filter_str}"
        )

    return catalog_features
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def get_available_locations(catalog_root: Path) -> list[str]:
    """List the locations present in the catalog (for UI dropdowns).

    Args:
        catalog_root: Path to catalog root directory

    Returns:
        ``["all"]`` followed by the sorted location names from the index
        statistics; just ``["all"]`` when the index cannot be read
        (best-effort helper, never raises).
    """
    try:
        stats = load_catalog_index(catalog_root).get("statistics", {})
        return ["all"] + sorted(stats.get("locations", []))
    except Exception:
        # Missing/corrupt catalog: fall back to the bare "all" option.
        return ["all"]
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def get_available_body_parts(catalog_root: Path) -> list[str]:
    """List the body parts present in the catalog (for UI dropdowns).

    Args:
        catalog_root: Path to catalog root directory

    Returns:
        ``["all"]`` followed by the sorted body-part names from the index
        statistics (e.g. head, left_flank, right_flank, tail, misc); just
        ``["all"]`` when the index cannot be read (best-effort helper,
        never raises).
    """
    try:
        stats = load_catalog_index(catalog_root).get("statistics", {})
        return ["all"] + sorted(stats.get("body_parts", []))
    except Exception:
        # Missing/corrupt catalog: fall back to the bare "all" option.
        return ["all"]
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def get_catalog_metadata_for_id(
    catalog_root: Path,
    catalog_id: str,
) -> dict | None:
    """Look up full metadata for a specific catalog ID.

    Walks every individual in the catalog index, reconstructs each reference
    image's catalog ID the same way the feature loaders do, and returns the
    first match.

    Args:
        catalog_root: Path to catalog root directory
        catalog_id: Catalog ID (e.g., "naguima_2022_001")

    Returns:
        Dict with ``leopard_name``, ``image_path``, ``individual_id`` and
        ``filename``, or None when no reference image matches.

    Raises:
        FileNotFoundError: If catalog doesn't exist
    """
    if not catalog_root.exists():
        raise FileNotFoundError(f"Catalog root not found: {catalog_root}")

    catalog_index = load_catalog_index(catalog_root)
    database_dir = catalog_root / "database"

    for entry in catalog_index["individuals"]:
        # Older index files use 'individual_name' instead of 'leopard_name'.
        name = entry.get("leopard_name") or entry.get("individual_name")
        site = entry.get("location", "")

        # Layout: database/{location}/{individual_name}/ (location optional).
        leopard_dir = database_dir / site / name if site else database_dir / name

        metadata = load_leopard_metadata(leopard_dir / "metadata.yaml")

        for ref_image in metadata["reference_images"]:
            # Rebuild the catalog ID exactly as the feature loaders do.
            candidate = f"{name.lower().replace(' ', '_')}_{ref_image['image_id']}"
            if candidate != catalog_id:
                continue

            return {
                "leopard_name": name,
                "image_path": database_dir / ref_image["path"],
                "individual_id": metadata["individual_id"],
                "filename": ref_image["filename"],
            }

    # No reference image reproduced the requested catalog ID.
    return None
|
src/snowleopard_reid/data_setup.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data setup utilities for extracting archives on first run.
|
| 2 |
+
|
| 3 |
+
This module handles the extraction of catalog and cache archives when the
|
| 4 |
+
application starts. Archives are extracted only once, on first run.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import tarfile
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
# Paths relative to project root
|
| 14 |
+
PROJECT_ROOT = Path(__file__).parent.parent.parent
|
| 15 |
+
DATA_DIR = PROJECT_ROOT / "data"
|
| 16 |
+
CATALOG_ARCHIVE = DATA_DIR / "catalog.tar.gz"
|
| 17 |
+
CATALOG_DIR = DATA_DIR / "catalog"
|
| 18 |
+
CACHE_ARCHIVE = DATA_DIR / "cache.tar.gz"
|
| 19 |
+
CACHE_DIR = PROJECT_ROOT / "cached_results"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def extract_archive(archive_path: Path, extract_to: Path) -> bool:
    """Extract a tar.gz archive to specified directory.

    Args:
        archive_path: Path to the .tar.gz archive
        extract_to: Directory to extract to (parent of archived dir)

    Returns:
        True if extraction successful, False otherwise (missing archive or
        extraction error; errors are logged, never raised).
    """
    if not archive_path.exists():
        logger.debug(f"Archive not found: {archive_path}")
        return False

    try:
        logger.info(f"Extracting {archive_path.name}...")
        extract_to.mkdir(parents=True, exist_ok=True)

        with tarfile.open(archive_path, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                # Python 3.12+ (and security backports): the "data" filter
                # rejects absolute paths, ".." traversal and special files,
                # hardening extraction against malicious archives.
                tar.extractall(path=extract_to, filter="data")
            else:
                # Older Pythons have no extraction filter; archives here are
                # bundled with the app, so this is an accepted risk.
                tar.extractall(path=extract_to)

        logger.info(f"Extracted {archive_path.name} successfully")
        return True

    except Exception as e:
        logger.error(f"Failed to extract {archive_path.name}: {e}")
        return False
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def ensure_data_extracted() -> None:
    """Extract catalog and cache archives on first run.

    Intended to be called at application startup. An archive is extracted
    only when the archive file exists AND its target directory does not —
    which makes repeated calls no-ops (idempotent).
    """
    # (archive, target dir, extraction destination, log label)
    jobs = (
        (CATALOG_ARCHIVE, CATALOG_DIR, DATA_DIR, "catalog data"),
        (CACHE_ARCHIVE, CACHE_DIR, PROJECT_ROOT, "cached results"),
    )
    for archive, target, destination, label in jobs:
        if archive.exists() and not target.exists():
            logger.info(f"First run detected - extracting {label}...")
            extract_archive(archive, destination)

    # Report the final state so startup logs show what is usable.
    if CATALOG_DIR.exists() and CACHE_DIR.exists():
        logger.debug("All data directories ready")
        return

    if not CATALOG_DIR.exists():
        logger.warning(f"Catalog not available: {CATALOG_DIR}")
    if not CACHE_DIR.exists():
        logger.warning(f"Cache not available: {CACHE_DIR}")
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def is_data_ready() -> bool:
    """Report whether the extracted catalog directory is present.

    Returns:
        True when the catalog directory exists on disk, False otherwise.
    """
    return CATALOG_DIR.exists()
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def is_cache_ready() -> bool:
    """Report whether the extracted cache directory is present.

    Returns:
        True when the cached-results directory exists on disk, False otherwise.
    """
    return CACHE_DIR.exists()
|
src/snowleopard_reid/features/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Features module for snow leopard re-identification.
|
| 2 |
+
|
| 3 |
+
This module provides utilities for extracting, loading, and saving features
|
| 4 |
+
from snow leopard images using various feature extractors (SIFT, SuperPoint, DISK, ALIKED).
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .extraction import (
|
| 8 |
+
extract_aliked_features,
|
| 9 |
+
extract_disk_features,
|
| 10 |
+
extract_features,
|
| 11 |
+
extract_sift_features,
|
| 12 |
+
extract_superpoint_features,
|
| 13 |
+
get_num_keypoints,
|
| 14 |
+
load_features,
|
| 15 |
+
save_features,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
"extract_features",
|
| 20 |
+
"extract_sift_features",
|
| 21 |
+
"extract_superpoint_features",
|
| 22 |
+
"extract_disk_features",
|
| 23 |
+
"extract_aliked_features",
|
| 24 |
+
"load_features",
|
| 25 |
+
"save_features",
|
| 26 |
+
"get_num_keypoints",
|
| 27 |
+
]
|
src/snowleopard_reid/features/extraction.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utilities for feature extraction and management.
|
| 2 |
+
|
| 3 |
+
This module provides functions for extracting, loading, and saving image features
|
| 4 |
+
using various feature extractors (SIFT, SuperPoint, DISK, ALIKED).
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from lightglue import ALIKED, DISK, SIFT, SuperPoint
|
| 11 |
+
from lightglue.utils import load_image, rbd
|
| 12 |
+
from PIL import Image
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def extract_sift_features(
    image_path: Path | str | Image.Image,
    max_num_keypoints: int = 2048,
    device: str = "cpu",
) -> dict[str, torch.Tensor]:
    """Extract SIFT features from an image.

    Args:
        image_path: Path to image file. PIL Image objects are not yet
            supported and raise TypeError.
        max_num_keypoints: Maximum number of keypoints to extract
            (default: 2048, valid range: 512-4096)
        device: Device to run extraction on ('cpu' or 'cuda')

    Returns:
        Dictionary with keys:
        - keypoints: Tensor of shape [N, 2] with (x, y) coordinates
        - descriptors: Tensor of shape [N, 128] with SIFT descriptors
        - scores: Tensor of shape [N] with keypoint scores
        - image_size: Tensor of shape [2] with (width, height)

    Raises:
        ValueError: If max_num_keypoints is out of valid range
        TypeError: If image_path is a PIL Image or an unsupported type
        FileNotFoundError: If image_path is a string/Path and file doesn't exist
    """
    # Validate parameters before doing any expensive work.
    if not (512 <= max_num_keypoints <= 4096):
        raise ValueError(
            f"max_num_keypoints must be in range 512-4096, got {max_num_keypoints}"
        )

    # Validate and load the image *before* constructing the extractor, so a
    # bad path fails fast instead of after model initialization.
    if isinstance(image_path, (str, Path)):
        image_path = Path(image_path)
        if not image_path.exists():
            raise FileNotFoundError(f"Image not found: {image_path}")
        # load_image returns torch.Tensor [3, H, W]
        image = load_image(str(image_path))
    elif isinstance(image_path, Image.Image):
        # lightglue's load_image() needs a file path; in-memory PIL images
        # are handled upstream by writing them to a temporary file.
        raise TypeError(
            "PIL Image input not yet supported, please provide path to image file"
        )
    else:
        raise TypeError(
            f"image_path must be str, Path, or PIL Image, got {type(image_path)}"
        )

    # Initialize extractor (construction can be expensive).
    extractor = SIFT(max_num_keypoints=max_num_keypoints).eval().to(device)

    # Move image to the requested device and run inference without gradients.
    image = image.to(device)
    with torch.no_grad():
        feats = extractor.extract(image)  # auto-resizes image
        feats = rbd(feats)  # remove batch dimension

    # Move features back to CPU so results are storage-friendly.
    if device != "cpu":
        feats = {
            k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in feats.items()
        }

    return feats
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def extract_superpoint_features(
    image_path: Path | str | Image.Image,
    max_num_keypoints: int = 2048,
    device: str = "cpu",
) -> dict[str, torch.Tensor]:
    """Extract SuperPoint features from an image.

    Args:
        image_path: Path to image file. PIL Image objects are not yet
            supported and raise TypeError.
        max_num_keypoints: Maximum number of keypoints to extract
            (default: 2048, valid range: 512-4096)
        device: Device to run extraction on ('cpu' or 'cuda')

    Returns:
        Dictionary with keys:
        - keypoints: Tensor of shape [N, 2] with (x, y) coordinates
        - descriptors: Tensor of shape [N, 256] with SuperPoint descriptors
        - scores: Tensor of shape [N] with keypoint scores
        - image_size: Tensor of shape [2] with (width, height)

    Raises:
        ValueError: If max_num_keypoints is out of valid range
        TypeError: If image_path is a PIL Image or an unsupported type
        FileNotFoundError: If image_path is a string/Path and file doesn't exist
    """
    # Validate parameters before doing any expensive work.
    if not (512 <= max_num_keypoints <= 4096):
        raise ValueError(
            f"max_num_keypoints must be in range 512-4096, got {max_num_keypoints}"
        )

    # Validate and load the image *before* constructing the extractor, so a
    # bad path fails fast instead of after model initialization (which may
    # load pretrained weights).
    if isinstance(image_path, (str, Path)):
        image_path = Path(image_path)
        if not image_path.exists():
            raise FileNotFoundError(f"Image not found: {image_path}")
        image = load_image(str(image_path))
    elif isinstance(image_path, Image.Image):
        raise TypeError(
            "PIL Image input not yet supported, please provide path to image file"
        )
    else:
        raise TypeError(
            f"image_path must be str, Path, or PIL Image, got {type(image_path)}"
        )

    # Initialize extractor (construction can be expensive).
    extractor = SuperPoint(max_num_keypoints=max_num_keypoints).eval().to(device)

    # Move image to the requested device and run inference without gradients.
    image = image.to(device)
    with torch.no_grad():
        feats = extractor.extract(image)
        feats = rbd(feats)  # remove batch dimension

    # Move features back to CPU so results are storage-friendly.
    if device != "cpu":
        feats = {
            k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in feats.items()
        }

    return feats
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def extract_disk_features(
    image_path: Path | str | Image.Image,
    max_num_keypoints: int = 2048,
    device: str = "cpu",
) -> dict[str, torch.Tensor]:
    """Extract DISK features from an image.

    Args:
        image_path: Path to image file. PIL Image objects are not yet
            supported and raise TypeError.
        max_num_keypoints: Maximum number of keypoints to extract
            (default: 2048, valid range: 512-4096)
        device: Device to run extraction on ('cpu' or 'cuda')

    Returns:
        Dictionary with keys:
        - keypoints: Tensor of shape [N, 2] with (x, y) coordinates
        - descriptors: Tensor of shape [N, 128] with DISK descriptors
        - scores: Tensor of shape [N] with keypoint scores
        - image_size: Tensor of shape [2] with (width, height)

    Raises:
        ValueError: If max_num_keypoints is out of valid range
        TypeError: If image_path is a PIL Image or an unsupported type
        FileNotFoundError: If image_path is a string/Path and file doesn't exist
    """
    # Validate parameters before doing any expensive work.
    if not (512 <= max_num_keypoints <= 4096):
        raise ValueError(
            f"max_num_keypoints must be in range 512-4096, got {max_num_keypoints}"
        )

    # Validate and load the image *before* constructing the extractor, so a
    # bad path fails fast instead of after model initialization (which may
    # load pretrained weights).
    if isinstance(image_path, (str, Path)):
        image_path = Path(image_path)
        if not image_path.exists():
            raise FileNotFoundError(f"Image not found: {image_path}")
        image = load_image(str(image_path))
    elif isinstance(image_path, Image.Image):
        raise TypeError(
            "PIL Image input not yet supported, please provide path to image file"
        )
    else:
        raise TypeError(
            f"image_path must be str, Path, or PIL Image, got {type(image_path)}"
        )

    # Initialize extractor (construction can be expensive).
    extractor = DISK(max_num_keypoints=max_num_keypoints).eval().to(device)

    # Move image to the requested device and run inference without gradients.
    image = image.to(device)
    with torch.no_grad():
        feats = extractor.extract(image)
        feats = rbd(feats)  # remove batch dimension

    # Move features back to CPU so results are storage-friendly.
    if device != "cpu":
        feats = {
            k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in feats.items()
        }

    return feats
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def extract_aliked_features(
    image_path: Path | str | Image.Image,
    max_num_keypoints: int = 2048,
    device: str = "cpu",
) -> dict[str, torch.Tensor]:
    """Extract ALIKED features from an image.

    Args:
        image_path: Path to image file. PIL Image objects are not yet
            supported and raise TypeError.
        max_num_keypoints: Maximum number of keypoints to extract
            (default: 2048, valid range: 512-4096)
        device: Device to run extraction on ('cpu' or 'cuda')

    Returns:
        Dictionary with keys:
        - keypoints: Tensor of shape [N, 2] with (x, y) coordinates
        - descriptors: Tensor of shape [N, 128] with ALIKED descriptors
        - scores: Tensor of shape [N] with keypoint scores
        - image_size: Tensor of shape [2] with (width, height)

    Raises:
        ValueError: If max_num_keypoints is out of valid range
        TypeError: If image_path is a PIL Image or an unsupported type
        FileNotFoundError: If image_path is a string/Path and file doesn't exist
    """
    # Validate parameters before doing any expensive work.
    if not (512 <= max_num_keypoints <= 4096):
        raise ValueError(
            f"max_num_keypoints must be in range 512-4096, got {max_num_keypoints}"
        )

    # Validate and load the image *before* constructing the extractor, so a
    # bad path fails fast instead of after model initialization (which may
    # load pretrained weights).
    if isinstance(image_path, (str, Path)):
        image_path = Path(image_path)
        if not image_path.exists():
            raise FileNotFoundError(f"Image not found: {image_path}")
        image = load_image(str(image_path))
    elif isinstance(image_path, Image.Image):
        raise TypeError(
            "PIL Image input not yet supported, please provide path to image file"
        )
    else:
        raise TypeError(
            f"image_path must be str, Path, or PIL Image, got {type(image_path)}"
        )

    # Initialize extractor (construction can be expensive).
    extractor = ALIKED(max_num_keypoints=max_num_keypoints).eval().to(device)

    # Move image to the requested device and run inference without gradients.
    image = image.to(device)
    with torch.no_grad():
        feats = extractor.extract(image)
        feats = rbd(feats)  # remove batch dimension

    # Move features back to CPU so results are storage-friendly.
    if device != "cpu":
        feats = {
            k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in feats.items()
        }

    return feats
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def load_features(features_path: Path | str) -> dict[str, torch.Tensor]:
|
| 275 |
+
"""Load features from a PyTorch .pt file.
|
| 276 |
+
|
| 277 |
+
Args:
|
| 278 |
+
features_path: Path to .pt file containing features
|
| 279 |
+
|
| 280 |
+
Returns:
|
| 281 |
+
Dictionary with feature tensors (keypoints, descriptors, scores, etc.)
|
| 282 |
+
|
| 283 |
+
Raises:
|
| 284 |
+
FileNotFoundError: If features file doesn't exist
|
| 285 |
+
RuntimeError: If file cannot be loaded
|
| 286 |
+
"""
|
| 287 |
+
features_path = Path(features_path)
|
| 288 |
+
|
| 289 |
+
if not features_path.exists():
|
| 290 |
+
raise FileNotFoundError(f"Features file not found: {features_path}")
|
| 291 |
+
|
| 292 |
+
try:
|
| 293 |
+
feats = torch.load(features_path, map_location="cpu", weights_only=False)
|
| 294 |
+
return feats
|
| 295 |
+
except Exception as e:
|
| 296 |
+
raise RuntimeError(f"Failed to load features from {features_path}: {e}")
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def save_features(
|
| 300 |
+
features: dict[str, torch.Tensor],
|
| 301 |
+
output_path: Path | str,
|
| 302 |
+
create_dirs: bool = True,
|
| 303 |
+
) -> None:
|
| 304 |
+
"""Save features to a PyTorch .pt file.
|
| 305 |
+
|
| 306 |
+
Args:
|
| 307 |
+
features: Dictionary with feature tensors to save
|
| 308 |
+
output_path: Path where to save .pt file
|
| 309 |
+
create_dirs: Whether to create parent directories if they don't exist
|
| 310 |
+
|
| 311 |
+
Raises:
|
| 312 |
+
ValueError: If features dict is empty or invalid
|
| 313 |
+
OSError: If directory creation or file writing fails
|
| 314 |
+
"""
|
| 315 |
+
if not features:
|
| 316 |
+
raise ValueError("Features dictionary is empty")
|
| 317 |
+
|
| 318 |
+
output_path = Path(output_path)
|
| 319 |
+
|
| 320 |
+
# Create parent directories if needed
|
| 321 |
+
if create_dirs and not output_path.parent.exists():
|
| 322 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 323 |
+
|
| 324 |
+
try:
|
| 325 |
+
torch.save(features, output_path)
|
| 326 |
+
except Exception as e:
|
| 327 |
+
raise OSError(f"Failed to save features to {output_path}: {e}")
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def get_num_keypoints(features: dict[str, torch.Tensor]) -> int:
    """Get the number of keypoints from a features dictionary.

    Args:
        features: Dictionary containing a 'keypoints' tensor

    Returns:
        Number of keypoints (first dimension of the keypoints tensor)

    Raises:
        KeyError: If 'keypoints' key is missing
    """
    if "keypoints" in features:
        return features["keypoints"].shape[0]
    raise KeyError("Features dictionary missing 'keypoints' key")
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def extract_features(
    extractor: str,
    image_path: Path | str | Image.Image,
    max_num_keypoints: int = 2048,
    device: str = "cpu",
) -> dict[str, torch.Tensor]:
    """Extract features from an image using the specified extractor.

    Factory function that dispatches (case-insensitively) to the matching
    feature extractor implementation.

    Args:
        extractor: Feature extractor name ('sift', 'superpoint', 'disk', 'aliked')
        image_path: Path to image file or PIL Image object
        max_num_keypoints: Maximum number of keypoints to extract (default: 2048)
        device: Device to run extraction on ('cpu' or 'cuda')

    Returns:
        Dictionary with feature tensors (keypoints, descriptors, scores, image_size)

    Raises:
        ValueError: If extractor name is not supported
        FileNotFoundError: If image_path is a string/Path and file doesn't exist

    Examples:
        >>> features = extract_features("sift", "image.jpg")
        >>> features = extract_features("sift", "image.jpg", max_num_keypoints=4096, device="cuda")
    """
    name = extractor.lower()

    # Table-driven dispatch keeps extractor registration in one place.
    dispatch = {
        "sift": extract_sift_features,
        "superpoint": extract_superpoint_features,
        "disk": extract_disk_features,
        "aliked": extract_aliked_features,
    }

    extract_fn = dispatch.get(name)
    if extract_fn is None:
        raise ValueError(
            f"Unsupported extractor: {name}. Supported extractors: sift, superpoint, disk, aliked"
        )
    return extract_fn(image_path, max_num_keypoints, device)
|
src/snowleopard_reid/images/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Image processing utilities for the Snow Leopard Re-ID project."""
|
| 2 |
+
|
| 3 |
+
from snowleopard_reid.images.processing import (
|
| 4 |
+
resize_image_if_needed,
|
| 5 |
+
resize_image_if_needed_cv2,
|
| 6 |
+
resize_image_if_needed_pil,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"resize_image_if_needed",
|
| 11 |
+
"resize_image_if_needed_cv2",
|
| 12 |
+
"resize_image_if_needed_pil",
|
| 13 |
+
]
|
src/snowleopard_reid/images/processing.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Image processing utilities for the Snow Leopard Re-ID project."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def resize_image_if_needed_pil(image: Image.Image, max_dim: int = 1024) -> Image.Image:
    """
    Resize a PIL Image so that neither dimension exceeds max_dim.

    Aspect ratio is preserved; images already within bounds are returned
    unchanged.

    Args:
        image: PIL Image to resize
        max_dim: Maximum allowed dimension (default: 1024)

    Returns:
        Resized image (or the original if no resize was needed)
    """
    width, height = image.size

    if max(width, height) <= max_dim:
        return image

    # One uniform scale factor keeps the aspect ratio intact.
    factor = min(max_dim / width, max_dim / height)
    new_width = int(width * factor)
    new_height = int(height * factor)

    # LANCZOS is a high-quality resampling filter for downscaling.
    shrunk = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
    logger.debug(f"Resized image from {width}x{height} to {new_width}x{new_height}")

    return shrunk
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def resize_image_if_needed_cv2(img: np.ndarray, max_dim: int = 1024) -> np.ndarray:
    """
    Resize an OpenCV (numpy) image so that neither dimension exceeds max_dim.

    Aspect ratio is preserved; images already within bounds are returned
    unchanged.

    Args:
        img: Input image as numpy array (H, W[, C])
        max_dim: Maximum allowed dimension (default: 1024)

    Returns:
        Resized image (or the original if no resize was needed)
    """
    height, width = img.shape[:2]

    if max(height, width) <= max_dim:
        return img

    # One uniform scale factor keeps the aspect ratio intact.
    factor = min(max_dim / height, max_dim / width)
    new_width = int(width * factor)
    new_height = int(height * factor)

    # INTER_AREA is the recommended interpolation for downscaling.
    shrunk = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_AREA)
    logger.debug(f"Resized image from {width}x{height} to {new_width}x{new_height}")

    return shrunk
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def resize_image_if_needed(
    image: Image.Image | np.ndarray, max_dim: int = 1024
) -> Image.Image | np.ndarray:
    """
    Resize image if either dimension exceeds max_dim, maintaining aspect ratio.

    Dispatches to the PIL or OpenCV implementation based on the input type.

    Args:
        image: PIL Image or numpy array to resize
        max_dim: Maximum allowed dimension (default: 1024)

    Returns:
        Resized image (or original if no resize needed)

    Raises:
        TypeError: If the input is neither a PIL Image nor a numpy array
    """
    if isinstance(image, np.ndarray):
        return resize_image_if_needed_cv2(img=image, max_dim=max_dim)
    if isinstance(image, Image.Image):
        return resize_image_if_needed_pil(image=image, max_dim=max_dim)
    raise TypeError(f"Expected PIL.Image.Image or np.ndarray, got {type(image)}")
|
src/snowleopard_reid/masks/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Mask processing utilities for snow leopard re-identification."""
|
| 2 |
+
|
| 3 |
+
from snowleopard_reid.masks.processing import (
|
| 4 |
+
add_padding_to_bbox,
|
| 5 |
+
crop_and_mask_image,
|
| 6 |
+
get_mask_bbox,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"get_mask_bbox",
|
| 11 |
+
"add_padding_to_bbox",
|
| 12 |
+
"crop_and_mask_image",
|
| 13 |
+
]
|
src/snowleopard_reid/masks/processing.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utilities for mask processing and image cropping.
|
| 2 |
+
|
| 3 |
+
This module provides functions for working with binary segmentation masks,
|
| 4 |
+
calculating bounding boxes, and cropping images with masks applied.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_mask_bbox(mask: np.ndarray) -> tuple[int, int, int, int]:
    """Calculate the tight bounding box of a binary mask.

    Args:
        mask: Binary mask array (0=background, 255=foreground)

    Returns:
        Tuple of (x_min, y_min, x_max, y_max) in pixel coordinates

    Raises:
        ValueError: If mask is empty (no foreground pixels)
    """
    # Coordinates of every foreground pixel.
    ys, xs = np.nonzero(mask > 0)

    if ys.size == 0:
        raise ValueError("Mask is empty (no foreground pixels)")

    # Extremes of the foreground coordinates give the tight box.
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def add_padding_to_bbox(
    bbox: tuple[int, int, int, int],
    padding: int,
    image_width: int,
    image_height: int,
) -> tuple[int, int, int, int]:
    """Add padding to a bounding box, clamped to image boundaries.

    Args:
        bbox: Original bounding box (x_min, y_min, x_max, y_max)
        padding: Padding in pixels to add on all sides
        image_width: Image width for clamping
        image_height: Image height for clamping

    Returns:
        Padded bounding box (x_min, y_min, x_max, y_max)
    """
    x_min, y_min, x_max, y_max = bbox

    # Grow the box on all sides, but never step outside the image.
    return (
        max(x_min - padding, 0),
        max(y_min - padding, 0),
        min(x_max + padding, image_width - 1),
        min(y_max + padding, image_height - 1),
    )
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def crop_and_mask_image(
    image: Image.Image,
    mask: np.ndarray,
    bbox: tuple[int, int, int, int],
) -> Image.Image:
    """Crop image to bbox and apply mask with black background.

    Args:
        image: Original PIL Image
        mask: Binary mask array (same size as image, 0=background, 255=foreground)
        bbox: Bounding box to crop to (x_min, y_min, x_max, y_max)

    Returns:
        Cropped and masked PIL Image with black background
    """
    x_min, y_min, x_max, y_max = bbox

    # PIL crop boxes exclude the right/bottom edge, hence the +1.
    region = image.crop((x_min, y_min, x_max + 1, y_max + 1))
    keep = mask[y_min : y_max + 1, x_min : x_max + 1] > 0

    pixels = np.array(region)

    # For color images, let the 2D mask broadcast across the channel axis.
    if pixels.ndim == 3:
        keep = keep[:, :, np.newaxis]

    # Keep foreground pixels; everything else becomes black.
    masked = np.where(keep, pixels, 0)

    return Image.fromarray(masked.astype(np.uint8))
|
src/snowleopard_reid/pipeline/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pipeline module for snow leopard re-identification.
|
| 2 |
+
|
| 3 |
+
This module provides the complete end-to-end pipeline for identifying individual
|
| 4 |
+
snow leopards from query images.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .stages import (
|
| 8 |
+
run_feature_extraction_stage,
|
| 9 |
+
run_mask_selection_stage,
|
| 10 |
+
run_matching_stage,
|
| 11 |
+
run_preprocess_stage,
|
| 12 |
+
run_segmentation_stage,
|
| 13 |
+
select_best_mask,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
__all__ = [
|
| 17 |
+
"run_segmentation_stage",
|
| 18 |
+
"run_mask_selection_stage",
|
| 19 |
+
"run_preprocess_stage",
|
| 20 |
+
"run_feature_extraction_stage",
|
| 21 |
+
"run_matching_stage",
|
| 22 |
+
"select_best_mask",
|
| 23 |
+
]
|
src/snowleopard_reid/pipeline/stages/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pipeline stages for snow leopard re-identification.
|
| 2 |
+
|
| 3 |
+
This module contains all pipeline stages that process query images through
|
| 4 |
+
segmentation, feature extraction, and matching.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .feature_extraction import run_feature_extraction_stage
|
| 8 |
+
from .mask_selection import run_mask_selection_stage, select_best_mask
|
| 9 |
+
from .matching import run_matching_stage
|
| 10 |
+
from .preprocess import run_preprocess_stage
|
| 11 |
+
from .segmentation import run_segmentation_stage
|
| 12 |
+
|
| 13 |
+
__all__ = [
|
| 14 |
+
"run_segmentation_stage",
|
| 15 |
+
"run_mask_selection_stage",
|
| 16 |
+
"select_best_mask",
|
| 17 |
+
"run_preprocess_stage",
|
| 18 |
+
"run_feature_extraction_stage",
|
| 19 |
+
"run_matching_stage",
|
| 20 |
+
]
|
src/snowleopard_reid/pipeline/stages/feature_extraction.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Feature extraction stage for query images.
|
| 2 |
+
|
| 3 |
+
This module extracts features from cropped query images for matching
|
| 4 |
+
against the catalog.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import tempfile
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from PIL import Image
|
| 13 |
+
|
| 14 |
+
from snowleopard_reid import features, get_device
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def run_feature_extraction_stage(
    image: Image.Image | Path | str,
    extractor: str = "sift",
    max_keypoints: int = 2048,
    device: str | None = None,
) -> dict:
    """Extract features from the query image.

    Runs the configured keypoint/descriptor extractor on the preprocessed
    query image and packages the result in the standardized stage dict.

    Args:
        image: PIL Image object or path to image file
        extractor: Feature extractor to use (default: 'sift')
        max_keypoints: Maximum number of keypoints to extract (default: 2048)
        device: Device to run on ('cpu', 'cuda', or None for auto-detect)

    Returns:
        Stage dict with keys "stage_id", "stage_name", "description",
        "config" (extractor, max_keypoints, device), "metrics"
        (num_keypoints), and "data" ("features": dict of keypoints [N, 2],
        descriptors [N, D], scores [N], image_size [2]).

    Raises:
        ValueError: If extractor is not supported
        FileNotFoundError: If image path doesn't exist
        RuntimeError: If feature extraction fails
    """
    # Resolve the execution device up front (auto-detect when None).
    resolved_device = get_device(device=device, verbose=True)

    # Delegate the actual extraction to the shared helper.
    extracted = _extract_features_from_image(
        image=image,
        extractor=extractor,
        max_keypoints=max_keypoints,
        device=resolved_device,
    )

    keypoint_count = features.get_num_keypoints(extracted)
    logger.info(f"Extracted {keypoint_count} keypoints using {extractor.upper()}")

    # Package everything in the standardized stage dict shape.
    return {
        "stage_id": "feature_extraction",
        "stage_name": "Feature Extraction",
        "description": "Extract keypoints and descriptors",
        "config": {
            "extractor": extractor,
            "max_keypoints": max_keypoints,
            "device": resolved_device,
        },
        "metrics": {
            "num_keypoints": keypoint_count,
        },
        "data": {
            "features": extracted,
        },
    }
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _extract_features_from_image(
    image: Image.Image | Path | str,
    extractor: str,
    max_keypoints: int,
    device: str,
) -> dict[str, torch.Tensor]:
    """Extract features from a PIL Image or an image path using the specified extractor.

    This is a wrapper that handles PIL Image input by saving it to a temporary
    file, since lightglue's load_image() requires a file path.

    Args:
        image: PIL Image or path to an image file
        extractor: Feature extractor to use ('sift', 'superpoint', 'disk', 'aliked')
        max_keypoints: Maximum keypoints to extract
        device: Device to use ('cpu' or 'cuda')

    Returns:
        Features dictionary as produced by features.extract_features()
    """
    if isinstance(image, Image.Image):
        # JPEG cannot encode alpha or palette modes; saving RGBA/LA/P images
        # directly would raise OSError, so convert those to RGB first.
        if image.mode in ("RGBA", "LA", "P", "PA"):
            image = image.convert("RGB")

        # Persist to a temporary file because load_image() needs a path.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            tmp_path = Path(tmp.name)
            image.save(tmp_path, quality=95)

        try:
            # Extract features from the temporary file
            feats = features.extract_features(
                extractor, tmp_path, max_keypoints, device
            )
        finally:
            # Always clean up the temporary file, even if extraction fails
            tmp_path.unlink()

        return feats
    else:
        # Image is already a filesystem path
        return features.extract_features(extractor, image, max_keypoints, device)
|
src/snowleopard_reid/pipeline/stages/mask_selection.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Mask selection stage for choosing the best snow leopard mask.
|
| 2 |
+
|
| 3 |
+
This module provides logic for selecting the best mask from multiple YOLO predictions.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def select_best_mask(
    predictions: list[dict],
    strategy: str = "confidence_area",
) -> tuple[int, dict]:
    """Pick the best mask among YOLO predictions using the given strategy.

    Strategies:
        - 'confidence_area': maximize confidence * mask pixel area
        - 'confidence':      maximize detection confidence
        - 'area':            maximize mask pixel area
        - 'center':          minimize distance of the bbox center from the image center

    Args:
        predictions: Prediction dicts from the YOLO segmentation stage
        strategy: Selection strategy name

    Returns:
        Tuple of (selected_index, selected_prediction)

    Raises:
        ValueError: If predictions is empty or the strategy is unknown
    """
    if not predictions:
        raise ValueError("Predictions list is empty")

    valid_strategies = ["confidence_area", "confidence", "area", "center"]
    if strategy not in valid_strategies:
        raise ValueError(
            f"Invalid strategy '{strategy}'. Valid strategies: {valid_strategies}"
        )

    def _mask_area(pred: dict) -> int:
        # Count of non-background pixels in the binary mask.
        return int(np.sum(pred["mask"] > 0))

    def _center_offset(pred: dict) -> float:
        # bbox centers are normalized to [0, 1]; image center is (0.5, 0.5).
        box = pred["bbox_xywhn"]
        return float(np.hypot(box["x_center"] - 0.5, box["y_center"] - 0.5))

    indices = range(len(predictions))
    if strategy == "confidence_area":
        chosen = max(
            indices,
            key=lambda i: predictions[i]["confidence"] * _mask_area(predictions[i]),
        )
    elif strategy == "confidence":
        chosen = max(indices, key=lambda i: predictions[i]["confidence"])
    elif strategy == "area":
        chosen = max(indices, key=lambda i: _mask_area(predictions[i]))
    else:  # strategy == "center"
        chosen = min(indices, key=lambda i: _center_offset(predictions[i]))

    return chosen, predictions[chosen]
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def run_mask_selection_stage(
    predictions: list[dict],
    strategy: str = "confidence_area",
) -> dict:
    """Run the mask selection stage.

    Chooses the single best mask among the YOLO predictions using the given
    selection strategy and wraps the result in the standardized stage dict.

    Args:
        predictions: List of prediction dicts from the segmentation stage
        strategy: Selection strategy (default: 'confidence_area')

    Returns:
        Stage dict with 'stage_id', 'stage_name', 'description', 'config'
        (strategy), 'metrics' (num_candidates, selected_index,
        selected_confidence), and 'data' (selected_prediction plus a metadata
        dict with strategy, selected_index, num_candidates, confidence,
        mask_area).

    Raises:
        ValueError: If predictions list is empty
    """
    logger.info(f"Selecting best mask using strategy: {strategy}")

    chosen_idx, chosen = select_best_mask(predictions, strategy)

    # Derived metadata about the chosen mask.
    area_px = int(np.sum(chosen["mask"] > 0))
    conf = chosen["confidence"]

    logger.info(
        f"Selected mask {chosen_idx} (confidence={conf:.3f}, area={area_px})"
    )

    # Standardized stage dict
    return {
        "stage_id": "mask_selection",
        "stage_name": "Mask Selection",
        "description": "Select best mask from predictions",
        "config": {"strategy": strategy},
        "metrics": {
            "num_candidates": len(predictions),
            "selected_index": chosen_idx,
            "selected_confidence": conf,
        },
        "data": {
            "selected_prediction": chosen,
            "metadata": {
                "strategy": strategy,
                "selected_index": chosen_idx,
                "num_candidates": len(predictions),
                "confidence": conf,
                "mask_area": area_px,
            },
        },
    }
|
src/snowleopard_reid/pipeline/stages/matching.py
ADDED
|
@@ -0,0 +1,593 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Matching stage for snow leopard identification.
|
| 2 |
+
|
| 3 |
+
This module handles matching query features against the catalog using
|
| 4 |
+
LightGlue and computing matching metrics.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
from lightglue import LightGlue
|
| 14 |
+
from scipy.stats import wasserstein_distance
|
| 15 |
+
|
| 16 |
+
from snowleopard_reid import get_device
|
| 17 |
+
from snowleopard_reid.catalog import (
|
| 18 |
+
get_all_catalog_features,
|
| 19 |
+
get_catalog_metadata_for_id,
|
| 20 |
+
get_filtered_catalog_features,
|
| 21 |
+
load_catalog_index,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _batched_to_device(
    feats: dict[str, torch.Tensor], device: str
) -> dict[str, torch.Tensor]:
    """Return a copy of a feature dict with tensors batched and moved to device.

    1-D and 2-D tensors get a leading batch dimension prepended (LightGlue
    expects batched inputs); non-tensor values are passed through unchanged.
    """
    prepared = {}
    for key, value in feats.items():
        if isinstance(value, torch.Tensor):
            if value.ndim in (1, 2):
                value = value.unsqueeze(0)
            prepared[key] = value.to(device)
        else:
            prepared[key] = value
    return prepared


def _save_pairwise_matches(
    enriched_matches: list[dict],
    raw_matches_cache: dict[str, dict],
    query_features: dict[str, torch.Tensor],
    query_image_path: str | None,
    pairwise_output_dir: Path,
) -> None:
    """Save per-match keypoint data as compressed NPZ files.

    Mutates each match dict in place: sets 'pairwise_file' to the saved file's
    path relative to the matching stage dir, or None when saving was skipped.
    """
    pairwise_output_dir = Path(pairwise_output_dir)
    pairwise_output_dir.mkdir(parents=True, exist_ok=True)

    logger.info(
        f"Saving pairwise match data for top-{len(enriched_matches)} matches"
    )

    # Query image size (shared by all pairwise files).
    query_image_size = query_features.get("image_size")
    if isinstance(query_image_size, torch.Tensor):
        query_image_size = query_image_size.cpu().numpy()

    for enriched_match in enriched_matches:
        catalog_id = enriched_match["catalog_id"]

        if catalog_id not in raw_matches_cache:
            logger.warning(
                f"No cached match data for {catalog_id}, skipping pairwise save"
            )
            enriched_match["pairwise_file"] = None
            continue

        cached = raw_matches_cache[catalog_id]
        catalog_feats = cached["catalog_features"]

        try:
            matched_data = extract_matched_keypoints(
                query_features=query_features,
                catalog_features=catalog_feats,
                matches=cached["matches"],
            )
        except Exception as e:
            logger.warning(
                f"Failed to extract keypoints for {catalog_id}: {e}, skipping"
            )
            enriched_match["pairwise_file"] = None
            continue

        catalog_image_size = catalog_feats.get("image_size")
        if isinstance(catalog_image_size, torch.Tensor):
            catalog_image_size = catalog_image_size.cpu().numpy()

        pairwise_data = {
            "rank": enriched_match["rank"],
            "catalog_id": catalog_id,
            "leopard_name": enriched_match["leopard_name"],
            "query_image_path": query_image_path or "",
            "catalog_image_path": enriched_match["filepath"],
            "query_image_size": query_image_size,
            "catalog_image_size": catalog_image_size,
            "query_keypoints": matched_data["query_keypoints"],
            "catalog_keypoints": matched_data["catalog_keypoints"],
            "match_scores": matched_data["match_scores"],
            "wasserstein": enriched_match["wasserstein"],
            "auc": enriched_match["auc"],
            # Use the actual count from the extracted keypoints.
            "num_matches": matched_data["num_matches"],
        }

        output_filename = f"rank_{enriched_match['rank']:02d}_{catalog_id}.npz"
        np.savez_compressed(pairwise_output_dir / output_filename, **pairwise_data)

        # Path is relative to the matching stage directory.
        enriched_match["pairwise_file"] = f"pairwise/{output_filename}"

    logger.info(f"Saved pairwise data to {pairwise_output_dir}")


def run_matching_stage(
    query_features: dict[str, torch.Tensor],
    catalog_path: Path | str,
    top_k: int = 5,
    extractor: str = "sift",
    device: str | None = None,
    query_image_path: str | None = None,
    pairwise_output_dir: Path | None = None,
    filter_locations: list[str] | None = None,
    filter_body_parts: list[str] | None = None,
) -> dict:
    """Match query against catalog.

    This stage matches the query features against all catalog images using
    LightGlue, computes metrics, and ranks matches.

    Args:
        query_features: Query features dict with keypoints, descriptors, scores
        catalog_path: Path to catalog root directory (e.g., data/08_catalog/v1.0/)
        top_k: Number of top matches to return (default: 5)
        extractor: Feature extractor used (default: 'sift')
        device: Device to run matching on ('cpu', 'cuda', or None for auto-detect)
        query_image_path: Path to query image (optional, for pairwise data)
        pairwise_output_dir: Directory to save pairwise match data (optional)
        filter_locations: List of locations to filter catalog by
        filter_body_parts: List of body parts to filter catalog by

    Returns:
        Stage dict with 'stage_id', 'stage_name', 'description', 'config',
        'metrics', and 'data' ('catalog_info' plus a ranked 'matches' list
        enriched with catalog metadata and optional 'pairwise_file' paths).

    Raises:
        FileNotFoundError: If catalog doesn't exist
        ValueError: If extractor not available in catalog or matcher init fails
        RuntimeError: If no catalog image could be matched
    """
    catalog_path = Path(catalog_path)

    if not catalog_path.exists():
        raise FileNotFoundError(f"Catalog not found: {catalog_path}")

    # Auto-detect device
    device = get_device(device=device, verbose=True)

    # Load catalog index
    logger.info(f"Loading catalog from {catalog_path}")
    catalog_index = load_catalog_index(catalog_path)
    logger.info(
        f"Catalog v{catalog_index['catalog_version']}: "
        f"{catalog_index['statistics']['total_individuals']} individuals, "
        f"{catalog_index['statistics']['total_reference_images']} images"
    )

    # Load catalog features (with optional filtering)
    if filter_locations or filter_body_parts:
        filter_info = []
        if filter_locations:
            filter_info.append(f"locations={filter_locations}")
        if filter_body_parts:
            filter_info.append(f"body_parts={filter_body_parts}")
        logger.info(f"Loading filtered catalog features ({', '.join(filter_info)})")
        try:
            catalog_features = get_filtered_catalog_features(
                catalog_root=catalog_path,
                extractor=extractor,
                locations=filter_locations,
                body_parts=filter_body_parts,
            )
            logger.info(f"Loaded {len(catalog_features)} filtered catalog features")
        except ValueError as e:
            # Chain the cause so the original traceback is preserved (PEP 3134).
            raise ValueError(f"Failed to load filtered catalog features: {e}") from e
    else:
        logger.info(f"Loading catalog features (extractor: {extractor})")
        try:
            catalog_features = get_all_catalog_features(
                catalog_root=catalog_path, extractor=extractor
            )
            logger.info(f"Loaded {len(catalog_features)} catalog features")
        except ValueError as e:
            raise ValueError(f"Failed to load catalog features: {e}") from e

    # Initialize LightGlue matcher
    logger.info(f"Initializing LightGlue matcher with {extractor} features")
    try:
        matcher = LightGlue(features=extractor).eval().to(device)
    except Exception as e:
        raise ValueError(
            f"Failed to initialize LightGlue matcher with extractor '{extractor}': {e}"
        ) from e

    # Batch + move the query features once; catalog features per item below.
    query_feats = _batched_to_device(query_features, device)

    # Serial matching: iterate through catalog
    logger.info(f"Matching against {len(catalog_features)} catalog images")
    matches_dict = {}
    raw_matches_cache = {}  # Store raw matches for pairwise saving

    for catalog_id, catalog_feats in catalog_features.items():
        catalog_feats_device = _batched_to_device(catalog_feats, device)

        # Run matcher
        try:
            with torch.no_grad():
                matches = matcher(
                    {
                        "image0": query_feats,
                        "image1": catalog_feats_device,
                    }
                )
        except Exception as e:
            logger.warning(f"Matching failed for {catalog_id}: {e}")
            continue

        # Compute metrics
        try:
            matches_dict[catalog_id] = compute_match_metrics(matches)

            # Cache raw matches and features for top-k pairwise saving
            if pairwise_output_dir is not None:
                raw_matches_cache[catalog_id] = {
                    "matches": matches,
                    "catalog_features": catalog_feats,
                }
        except KeyError as e:
            logger.warning(f"Failed to compute metrics for {catalog_id}: {e}")
            continue

    logger.info(f"Successfully matched against {len(matches_dict)} catalog images")

    if not matches_dict:
        raise RuntimeError(
            "No successful matches found. All catalog images failed to match. "
            "This may indicate a problem with feature extraction or format."
        )

    # Rank matches by Wasserstein distance
    ranked_matches = rank_matches(matches_dict, metric="wasserstein", top_k=top_k)

    # Enrich matches with catalog metadata
    enriched_matches = []
    for match in ranked_matches:
        catalog_id = match["catalog_id"]
        metadata = get_catalog_metadata_for_id(
            catalog_root=catalog_path, catalog_id=catalog_id
        )

        if metadata is None:
            logger.warning(f"No metadata found for {catalog_id}")
            continue

        enriched_matches.append(
            {
                "rank": match["rank"],
                "catalog_id": catalog_id,
                "leopard_name": metadata["leopard_name"],
                "filepath": str(metadata["image_path"]),
                "wasserstein": match["wasserstein"],
                "auc": match["auc"],
                "num_matches": match["num_matches"],
                "individual_id": metadata["individual_id"],
            }
        )

    if enriched_matches:
        logger.info(
            f"Top match: {enriched_matches[0]['leopard_name']} "
            f"(wasserstein: {enriched_matches[0]['wasserstein']:.4f}, "
            f"matches: {enriched_matches[0]['num_matches']})"
        )

    # Save pairwise match data for top-k matches
    if pairwise_output_dir is not None and enriched_matches:
        _save_pairwise_matches(
            enriched_matches=enriched_matches,
            raw_matches_cache=raw_matches_cache,
            query_features=query_features,
            query_image_path=query_image_path,
            pairwise_output_dir=Path(pairwise_output_dir),
        )
    else:
        # Set pairwise_file to None if not saving pairwise data
        for enriched_match in enriched_matches:
            enriched_match["pairwise_file"] = None

    # Return standardized stage dict
    return {
        "stage_id": "matching",
        "stage_name": "Matching",
        "description": "Match query against catalog using LightGlue",
        "config": {
            "top_k": top_k,
            "extractor": extractor,
            "device": device,
            "catalog_path": str(catalog_path),
            "filter_locations": filter_locations,
            "filter_body_parts": filter_body_parts,
        },
        "metrics": {
            "num_catalog_images": len(catalog_features),
            "num_successful_matches": len(matches_dict),
            "top_match_wasserstein": enriched_matches[0]["wasserstein"]
            if enriched_matches
            else 0.0,
            "top_match_leopard_name": enriched_matches[0]["leopard_name"]
            if enriched_matches
            else "",
        },
        "data": {
            "catalog_info": {
                "catalog_version": catalog_index["catalog_version"],
                "catalog_path": str(catalog_path),
                "num_individuals": catalog_index["statistics"]["total_individuals"],
                "num_reference_images": catalog_index["statistics"][
                    "total_reference_images"
                ],
            },
            "matches": enriched_matches,
        },
    }
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
# ============================================================================
|
| 350 |
+
# Metrics Utilities
|
| 351 |
+
# ============================================================================
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def compute_wasserstein_distance(scores: np.ndarray) -> float:
    """Compute the Wasserstein distance of the score distribution from a null one.

    The null reference is a fixed-length (1024) array of zeros, representing
    "no matches at all"; using a fixed length keeps every comparison anchored
    to the same reference distribution (follows the trout-reID implementation).
    Higher distances indicate better matches — this is the optimal metric for
    re-identification tasks.

    Args:
        scores: Array of match scores (typically from matcher output)

    Returns:
        Wasserstein (Earth Mover's) distance as a float; 0.0 for empty input

    References:
        Based on trout-reid implementation for animal re-identification
    """
    if len(scores) == 0:
        return 0.0

    # Reference distribution: all-zero, fixed length for comparability.
    null_reference = np.zeros(1024)
    return float(wasserstein_distance(null_reference, scores))
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def compute_auc(scores: np.ndarray) -> float:
    """Compute Area Under Curve (cumulative distribution) of match scores.

    Scores are sorted ascending, cumulatively summed and normalized so the
    curve ends at 1.0; the trapezoidal area of that curve divided by the
    number of scores gives an AUC in [0, 1]. Higher values indicate better
    matches.

    Args:
        scores: Array of match scores (typically from matcher output)

    Returns:
        AUC value as a float (0.0 to 1.0); 0.0 for empty input or when the
        total score mass is not positive

    References:
        Based on trout-reid implementation
    """
    if len(scores) == 0:
        return 0.0

    # Sort scores in ascending order and accumulate.
    sorted_scores = np.sort(scores)
    cumsum = np.cumsum(sorted_scores)

    # No positive mass: nothing to integrate.
    if cumsum[-1] <= 0:
        return 0.0

    # np.trapz was removed in NumPy 2.0 (renamed np.trapezoid); support both.
    trapezoid = getattr(np, "trapezoid", None)
    if trapezoid is None:
        trapezoid = np.trapz  # NumPy < 2.0

    auc = trapezoid(cumsum / cumsum[-1]) / len(scores)
    return float(auc)
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def extract_match_scores(matches: dict[str, torch.Tensor]) -> np.ndarray:
    """Pull the valid match scores out of a LightGlue matcher result.

    Args:
        matches: Matcher output dictionary; must contain 'matching_scores0'
            (a tensor or array of per-keypoint match confidence scores)

    Returns:
        1-D numpy array holding only the strictly positive scores

    Raises:
        KeyError: If 'matching_scores0' is absent from the dict
    """
    if "matching_scores0" not in matches:
        raise KeyError("matches dictionary missing 'matching_scores0' key")

    raw = matches["matching_scores0"]

    # Normalize to numpy for filtering.
    if isinstance(raw, torch.Tensor):
        raw = raw.cpu().numpy()

    # Unmatched keypoints carry zero (or negative) scores; drop them.
    return raw[raw > 0]
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
def extract_matched_keypoints(
    query_features: dict[str, torch.Tensor],
    catalog_features: dict[str, torch.Tensor],
    matches: dict[str, torch.Tensor],
) -> dict[str, np.ndarray]:
    """Collect the matched keypoint pairs out of a LightGlue matcher result.

    Args:
        query_features: Feature dict with a 'keypoints' tensor [M, 2]
        catalog_features: Feature dict with a 'keypoints' tensor [N, 2]
        matches: LightGlue output containing:
            - matches0: [M] query_idx -> catalog_idx mapping (-1 = unmatched)
            - matching_scores0: [M] confidence per query keypoint

    Returns:
        Dict with:
            - query_keypoints: ndarray [K, 2] - matched query keypoints
            - catalog_keypoints: ndarray [K, 2] - matched catalog keypoints
            - match_scores: ndarray [K] - confidence scores
            - num_matches: int - number of valid matches K

    Raises:
        KeyError: If 'matches0' or 'matching_scores0' is missing
    """
    if "matches0" not in matches or "matching_scores0" not in matches:
        raise KeyError(
            "matches dictionary missing 'matches0' or 'matching_scores0' keys"
        )

    def _as_unbatched_array(value, batched_ndim: int) -> np.ndarray:
        # Convert tensor -> numpy and strip a leading batch axis if present.
        if isinstance(value, torch.Tensor):
            value = value.cpu().numpy()
        if value.ndim == batched_ndim:
            value = value[0]
        return value

    match_indices = _as_unbatched_array(matches["matches0"], 2)          # [M]
    match_scores = _as_unbatched_array(matches["matching_scores0"], 2)   # [M]
    query_kpts = _as_unbatched_array(query_features["keypoints"], 3)     # [M, 2]
    catalog_kpts = _as_unbatched_array(catalog_features["keypoints"], 3)  # [N, 2]

    # A pair is valid only when matched (index >= 0) and positively scored.
    keep = (match_indices >= 0) & (match_scores > 0)
    catalog_rows = match_indices[keep].astype(int)
    kept_scores = match_scores[keep]

    return {
        "query_keypoints": query_kpts[keep],
        "catalog_keypoints": catalog_kpts[catalog_rows],
        "match_scores": kept_scores,
        "num_matches": len(kept_scores),
    }
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
def rank_matches(
|
| 526 |
+
matches_dict: dict[str, dict[str, Any]],
|
| 527 |
+
metric: str = "wasserstein",
|
| 528 |
+
top_k: int = None,
|
| 529 |
+
) -> list[dict[str, Any]]:
|
| 530 |
+
"""Rank matches by specified metric.
|
| 531 |
+
|
| 532 |
+
Args:
|
| 533 |
+
matches_dict: Dictionary mapping catalog_id to match info
|
| 534 |
+
metric: Metric to rank by ('wasserstein' or 'auc')
|
| 535 |
+
top_k: Number of top matches to return (None = all)
|
| 536 |
+
|
| 537 |
+
Returns:
|
| 538 |
+
List of match dictionaries sorted by metric (best first)
|
| 539 |
+
|
| 540 |
+
Raises:
|
| 541 |
+
ValueError: If metric is not supported
|
| 542 |
+
"""
|
| 543 |
+
if metric not in ["wasserstein", "auc"]:
|
| 544 |
+
raise ValueError(f"Unsupported metric: {metric}. Use 'wasserstein' or 'auc'")
|
| 545 |
+
|
| 546 |
+
# Convert dict to list with catalog_id included
|
| 547 |
+
matches_list = [
|
| 548 |
+
{"catalog_id": cid, **match_info} for cid, match_info in matches_dict.items()
|
| 549 |
+
]
|
| 550 |
+
|
| 551 |
+
# Sort by metric (descending - higher is better for both metrics)
|
| 552 |
+
sorted_matches = sorted(
|
| 553 |
+
matches_list,
|
| 554 |
+
key=lambda x: x.get(metric, 0.0),
|
| 555 |
+
reverse=True,
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
# Add rank
|
| 559 |
+
for rank, match in enumerate(sorted_matches, start=1):
|
| 560 |
+
match["rank"] = rank
|
| 561 |
+
|
| 562 |
+
# Return top_k if specified
|
| 563 |
+
if top_k is not None:
|
| 564 |
+
return sorted_matches[:top_k]
|
| 565 |
+
|
| 566 |
+
return sorted_matches
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
def compute_match_metrics(matches: dict[str, torch.Tensor]) -> dict[str, float]:
    """Compute all matching metrics for a single match result.

    Args:
        matches: Dictionary from LightGlue matcher

    Returns:
        Dictionary with computed metrics:
            - wasserstein: float
            - auc: float
            - num_matches: int

    Raises:
        KeyError: If matches dict is missing required keys
    """
    try:
        scores = extract_match_scores(matches)

        return {
            "wasserstein": compute_wasserstein_distance(scores),
            "auc": compute_auc(scores),
            "num_matches": len(scores),
        }
    except KeyError as e:
        # Chain the original exception so the underlying traceback survives.
        raise KeyError(f"Failed to compute metrics: {e}") from e
|
src/snowleopard_reid/pipeline/stages/preprocess.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Preprocessing stage for cropping and masking snow leopard images.
|
| 2 |
+
|
| 3 |
+
This module provides preprocessing operations to extract and mask the leopard region
|
| 4 |
+
from the full image.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import cv2
|
| 11 |
+
import numpy as np
|
| 12 |
+
from PIL import Image
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def run_preprocess_stage(
    image_path: Path | str,
    mask: np.ndarray,
    padding: int = 5,
) -> dict:
    """Run preprocessing stage.

    This stage crops the image to the mask bounding box with padding and applies
    the mask to isolate the leopard region.

    Args:
        image_path: Path to input image
        mask: Binary mask (H×W, uint8) from segmentation
        padding: Padding around bbox in pixels (default: 5)

    Returns:
        Stage dict with structure:
        {
            "stage_id": "preprocessing",
            "stage_name": "Preprocessing",
            "description": "Crop and mask leopard region",
            "config": {
                "padding": int
            },
            "metrics": {
                "original_size": {"width": int, "height": int},
                "crop_size": {"width": int, "height": int}
            },
            "data": {
                "cropped_image": PIL.Image,
                "metadata": {
                    "original_size": {"width": int, "height": int},
                    "crop_bbox": {"x_min": int, "y_min": int, "x_max": int, "y_max": int},
                    "crop_size": {"width": int, "height": int}
                }
            }
        }

    Raises:
        FileNotFoundError: If image doesn't exist
        RuntimeError: If the image file exists but cannot be decoded
        ValueError: If mask is empty (contains no pixels > 0)
    """
    image_path = Path(image_path)

    if not image_path.exists():
        raise FileNotFoundError(f"Image not found: {image_path}")

    logger.info(f"Preprocessing image: {image_path}")

    # Load image; cv2.imread returns None (instead of raising) on decode failure.
    image = cv2.imread(str(image_path))
    if image is None:
        raise RuntimeError(f"Failed to load image: {image_path}")
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_height, image_width = image_rgb.shape[:2]

    # Resize mask to match image dimensions if needed.
    # INTER_NEAREST keeps the mask binary (no interpolated gray values).
    if mask.shape[:2] != (image_height, image_width):
        mask_resized = cv2.resize(
            mask.astype(np.uint8),
            (image_width, image_height),
            interpolation=cv2.INTER_NEAREST,
        )
    else:
        mask_resized = mask

    # Find bounding box of mask
    rows = np.any(mask_resized > 0, axis=1)
    cols = np.any(mask_resized > 0, axis=0)

    if not np.any(rows) or not np.any(cols):
        raise ValueError("Mask is empty (no pixels > 0)")

    y_min, y_max = np.where(rows)[0][[0, -1]]
    x_min, x_max = np.where(cols)[0][[0, -1]]

    # Add padding, clamped to the image bounds
    x_min = max(0, x_min - padding)
    y_min = max(0, y_min - padding)
    x_max = min(image_width - 1, x_max + padding)
    y_max = min(image_height - 1, y_max + padding)

    # Crop image and mask (bbox is inclusive, hence the +1)
    cropped_image = image_rgb[y_min : y_max + 1, x_min : x_max + 1]
    cropped_mask = mask_resized[y_min : y_max + 1, x_min : x_max + 1]

    # Apply mask (set non-masked pixels to black)
    masked_image = cropped_image.copy()
    masked_image[cropped_mask == 0] = 0

    # Convert to PIL Image
    cropped_pil = Image.fromarray(masked_image)

    crop_height, crop_width = masked_image.shape[:2]

    logger.info(
        f"Cropped from {image_width}x{image_height} to {crop_width}x{crop_height} "
        f"(padding={padding}px)"
    )

    # Return standardized stage dict
    return {
        "stage_id": "preprocessing",
        "stage_name": "Preprocessing",
        "description": "Crop and mask leopard region",
        "config": {
            "padding": padding,
        },
        "metrics": {
            "original_size": {"width": image_width, "height": image_height},
            "crop_size": {"width": crop_width, "height": crop_height},
        },
        "data": {
            "cropped_image": cropped_pil,
            "metadata": {
                "original_size": {"width": image_width, "height": image_height},
                "crop_bbox": {
                    "x_min": int(x_min),
                    "y_min": int(y_min),
                    "x_max": int(x_max),
                    "y_max": int(y_max),
                },
                "crop_size": {"width": crop_width, "height": crop_height},
            },
        },
    }
|
src/snowleopard_reid/pipeline/stages/segmentation.py
ADDED
|
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Segmentation stage using YOLO or GDINO+SAM for snow leopard detection.
|
| 2 |
+
|
| 3 |
+
This module provides segmentation stages that detect and segment snow leopards
|
| 4 |
+
in query images using either:
|
| 5 |
+
1. YOLO (end-to-end learned segmentation)
|
| 6 |
+
2. GDINO+SAM (zero-shot detection + prompted segmentation)
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any
|
| 12 |
+
|
| 13 |
+
import cv2
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
from PIL import Image
|
| 17 |
+
from segment_anything_hq import SamPredictor, sam_model_registry
|
| 18 |
+
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
|
| 19 |
+
from ultralytics import YOLO
|
| 20 |
+
|
| 21 |
+
from snowleopard_reid import get_device
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def load_gdino_model(
    model_id: str = "IDEA-Research/grounding-dino-base",
    device: str | None = None,
) -> tuple[Any, Any]:
    """Load a Grounding DINO processor/model pair ready for inference.

    Args:
        model_id: HuggingFace model identifier
        device: Device to load model on (None = auto-detect)

    Returns:
        Tuple of (processor, model)
    """
    target_device = get_device(device=device, verbose=True)

    logger.info(f"Loading Grounding DINO model: {model_id}")
    gdino_processor = AutoProcessor.from_pretrained(model_id)
    gdino_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(
        target_device
    )
    # Inference-only: disable dropout/batch-norm training behavior.
    gdino_model.eval()

    logger.info("Grounding DINO model loaded successfully")
    return gdino_processor, gdino_model
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def load_sam_predictor(
    checkpoint_path: Path | str,
    model_type: str = "vit_l",
    device: str | None = None,
) -> SamPredictor:
    """Build a SAM HQ predictor from a local checkpoint.

    Args:
        checkpoint_path: Path to SAM HQ checkpoint file
        model_type: Model type (vit_b, vit_l, vit_h)
        device: Device to load model on (None = auto-detect)

    Returns:
        SamPredictor instance

    Raises:
        FileNotFoundError: If the checkpoint file does not exist
    """
    ckpt = Path(checkpoint_path)
    if not ckpt.exists():
        raise FileNotFoundError(f"SAM checkpoint not found: {ckpt}")

    target_device = get_device(device=device, verbose=True)

    logger.info(f"Loading SAM HQ model: {model_type}")
    sam_model = sam_model_registry[model_type](checkpoint=str(ckpt))
    sam_model.to(device=target_device)

    sam_hq_predictor = SamPredictor(sam_model)
    logger.info("SAM HQ model loaded successfully")

    return sam_hq_predictor
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _run_yolo_segmentation(
    model: YOLO,
    image_path: Path,
    confidence_threshold: float,
    device: str,
) -> dict:
    """Run YOLO segmentation (internal implementation).

    Args:
        model: Pre-loaded YOLO model
        image_path: Path to input image
        confidence_threshold: Minimum confidence to keep predictions
        device: Device to run on

    Returns:
        Standardized stage dict (see run_segmentation_stage for the schema)

    Raises:
        RuntimeError: If the image cannot be loaded or YOLO inference fails
    """
    # Load image only to record its original size in the output dict
    image = cv2.imread(str(image_path))
    if image is None:
        raise RuntimeError(f"Failed to load image: {image_path}")
    image_height, image_width = image.shape[:2]

    # Run inference
    try:
        results = model(
            str(image_path),
            conf=confidence_threshold,
            device=device,
            verbose=False,
        )
    except Exception as e:
        # Chain so the underlying ultralytics traceback survives
        raise RuntimeError(f"YOLO inference failed: {e}") from e

    # Parse results
    predictions = []
    result = results[0]  # Single image, so single result

    # Diagnostic details belong at DEBUG level, not INFO
    logger.debug(f"Result object type: {type(result)}")
    logger.debug(f"Result has boxes: {result.boxes is not None}")
    logger.debug(f"Result has masks: {result.masks is not None}")
    if result.boxes is not None:
        logger.debug(f"Number of boxes: {len(result.boxes)}")
    if result.masks is not None:
        logger.debug(f"Number of masks: {len(result.masks)}")

    # Check if any detections found
    if result.masks is None or len(result.masks) == 0:
        logger.warning(f"No detections found for {image_path}")
        logger.warning(
            f"Boxes present: {result.boxes is not None}, Masks present: {result.masks is not None}"
        )
    else:
        # Extract masks and metadata
        for idx in range(len(result.masks)):
            # Get mask (binary, H×W)
            mask = result.masks.data[idx].cpu().numpy()  # Shape: (H, W)
            mask = (mask * 255).astype(np.uint8)  # Convert to 0-255

            # Get bounding box (normalized xywh format)
            bbox = result.boxes.xywhn[idx].cpu().numpy()  # Shape: (4,)
            x_center, y_center, width, height = bbox

            # Get confidence
            confidence = float(result.boxes.conf[idx].cpu().numpy())

            # Get class info
            class_id = int(result.boxes.cls[idx].cpu().numpy())
            class_name = result.names[class_id]

            predictions.append(
                {
                    "mask": mask,
                    "confidence": confidence,
                    "bbox_xywhn": {
                        "x_center": float(x_center),
                        "y_center": float(y_center),
                        "width": float(width),
                        "height": float(height),
                    },
                    "class_id": class_id,
                    "class_name": class_name,
                }
            )

    logger.info(
        f"Found {len(predictions)} predictions (confidence >= {confidence_threshold})"
    )

    # Return standardized stage dict
    return {
        "stage_id": "segmentation",
        "stage_name": "YOLO Segmentation",
        "description": "Snow leopard detection and segmentation using YOLO",
        "config": {
            "strategy": "yolo",
            "confidence_threshold": confidence_threshold,
            "device": device,
        },
        "metrics": {
            "num_predictions": len(predictions),
        },
        "data": {
            "image_path": str(image_path),
            "image_size": {"width": image_width, "height": image_height},
            "predictions": predictions,
        },
    }
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def _run_gdino_sam_segmentation(
    gdino_processor: Any,
    gdino_model: Any,
    sam_predictor: SamPredictor,
    image_path: Path,
    confidence_threshold: float,
    text_prompt: str,
    box_threshold: float,
    text_threshold: float,
    device: str,
) -> dict:
    """Run GDINO+SAM segmentation (internal implementation).

    Args:
        gdino_processor: Grounding DINO processor
        gdino_model: Grounding DINO model
        sam_predictor: SAM HQ predictor
        image_path: Path to input image
        confidence_threshold: Minimum confidence to keep predictions
        text_prompt: Text prompt for GDINO
        box_threshold: GDINO box threshold
        text_threshold: GDINO text threshold
        device: Device to run on

    Returns:
        Standardized stage dict (see run_segmentation_stage for the schema)
    """
    # Load image (PIL for GDINO, numpy for SAM)
    image_pil = Image.open(image_path).convert("RGB")
    image_np = np.array(image_pil)
    image_height, image_width = image_np.shape[:2]

    # Run Grounding DINO detection
    logger.info("Running Grounding DINO detection...")
    inputs = gdino_processor(images=image_pil, text=text_prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = gdino_model(**inputs)

    # Post-process GDINO outputs
    results = gdino_processor.post_process_grounded_object_detection(
        outputs,
        inputs["input_ids"],
        threshold=box_threshold,
        text_threshold=text_threshold,
        target_sizes=[image_pil.size[::-1]],  # (height, width)
    )[0]

    # "text_labels" is used by newer transformers versions, "labels" by older ones
    labels = results.get("text_labels", results.get("labels", []))
    boxes = results["boxes"]
    scores = results["scores"]

    logger.info(f"GDINO detected {len(boxes)} objects")

    # Filter predictions by confidence threshold
    filtered_detections = [
        (box, score, label)
        for box, score, label in zip(boxes, scores, labels)
        if float(score) >= confidence_threshold
    ]

    logger.info(
        f"Filtered to {len(filtered_detections)} detections (confidence >= {confidence_threshold})"
    )

    if not filtered_detections:
        logger.warning(f"No detections found for {image_path}")
        predictions = []
    else:
        # Set image for SAM (do this once; SAM embeds the image on set_image)
        logger.info("Running SAM HQ segmentation...")
        sam_predictor.set_image(image_np)

        predictions = []
        for box, gdino_score, label in filtered_detections:
            # Convert box to pixel coordinates and format for SAM
            x_min, y_min, x_max, y_max = box
            bbox_xyxy = np.array(
                [float(x_min), float(y_min), float(x_max), float(y_max)]
            )

            # Run SAM with bounding box prompt; logits output is unused here
            masks, sam_scores, _ = sam_predictor.predict(
                box=bbox_xyxy[None, :],
                multimask_output=False,
                hq_token_only=True,
            )

            # Get mask (first and only mask, since multimask_output=False)
            mask = masks[0]  # Shape: (H, W), boolean
            sam_score = float(sam_scores[0])

            # Convert mask to uint8 (0-255)
            mask_uint8 = (mask * 255).astype(np.uint8)

            # Convert bbox to normalized xywh format (same as YOLO)
            x_center = (float(x_min) + float(x_max)) / 2 / image_width
            y_center = (float(y_min) + float(y_max)) / 2 / image_height
            width = (float(x_max) - float(x_min)) / image_width
            height = (float(y_max) - float(y_min)) / image_height

            predictions.append(
                {
                    "mask": mask_uint8,
                    "confidence": float(
                        gdino_score
                    ),  # Use GDINO score as primary confidence
                    "bbox_xywhn": {
                        "x_center": x_center,
                        "y_center": y_center,
                        "width": width,
                        "height": height,
                    },
                    "class_id": 0,  # Single class (snow leopard)
                    "class_name": label,
                    # Additional metadata
                    "sam_score": sam_score,
                    "gdino_score": float(gdino_score),
                }
            )

    logger.info(f"Generated {len(predictions)} segmentation masks")

    # Return standardized stage dict
    return {
        "stage_id": "segmentation",
        "stage_name": "GDINO+SAM Segmentation",
        "description": "Snow leopard detection using Grounding DINO and segmentation using SAM HQ",
        "config": {
            "strategy": "gdino_sam",
            "confidence_threshold": confidence_threshold,
            "text_prompt": text_prompt,
            "box_threshold": box_threshold,
            "text_threshold": text_threshold,
            "device": device,
        },
        "metrics": {
            "num_predictions": len(predictions),
        },
        "data": {
            "image_path": str(image_path),
            "image_size": {"width": image_width, "height": image_height},
            "predictions": predictions,
        },
    }
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def run_segmentation_stage(
    image_path: Path | str,
    strategy: str = "yolo",
    confidence_threshold: float = 0.5,
    device: str | None = None,
    # YOLO-specific parameters
    yolo_model: YOLO | None = None,
    # GDINO+SAM-specific parameters
    gdino_processor: Any | None = None,
    gdino_model: Any | None = None,
    sam_predictor: SamPredictor | None = None,
    text_prompt: str = "a snow leopard.",
    box_threshold: float = 0.30,
    text_threshold: float = 0.20,
) -> dict:
    """Detect and segment the snow leopard in a query image.

    Validates inputs, resolves the compute device, then dispatches to the
    implementation matching ``strategy``:

    - "yolo": end-to-end learned segmentation (requires ``yolo_model``)
    - "gdino_sam": zero-shot detection followed by prompted segmentation
      (requires ``gdino_processor``, ``gdino_model``, and ``sam_predictor``)

    Args:
        image_path: Path to input image
        strategy: Segmentation strategy ("yolo" or "gdino_sam")
        confidence_threshold: Minimum confidence to keep predictions (default: 0.5)
        device: Device to run on ('cpu', 'cuda', or None for auto-detect)
        yolo_model: Pre-loaded YOLO model (required if strategy="yolo")
        gdino_processor: Pre-loaded GDINO processor (required if strategy="gdino_sam")
        gdino_model: Pre-loaded GDINO model (required if strategy="gdino_sam")
        sam_predictor: Pre-loaded SAM predictor (required if strategy="gdino_sam")
        text_prompt: Text prompt for GDINO (default: "a snow leopard.")
        box_threshold: GDINO box confidence threshold (default: 0.30)
        text_threshold: GDINO text matching threshold (default: 0.20)

    Returns:
        Standardized stage dict with "stage_id", "stage_name", "description",
        "config", "metrics" ("num_predictions"), and "data" ("image_path",
        "image_size", "predictions"). Each prediction carries "mask"
        (H×W uint8 ndarray), "confidence", "bbox_xywhn", "class_id", and
        "class_name"; GDINO+SAM predictions additionally carry "sam_score"
        and "gdino_score".

    Raises:
        ValueError: If strategy is invalid or required models are missing
        FileNotFoundError: If image doesn't exist
        RuntimeError: If inference fails
    """
    query_path = Path(image_path)

    # Validate inputs up front so callers fail fast with clear errors.
    if not query_path.exists():
        raise FileNotFoundError(f"Image not found: {query_path}")

    if strategy not in ("yolo", "gdino_sam"):
        raise ValueError(f"Invalid strategy: {strategy}. Must be 'yolo' or 'gdino_sam'")

    # Resolve device (auto-detects CUDA when None).
    resolved_device = get_device(device=device, verbose=True)

    if strategy == "yolo":
        if yolo_model is None:
            raise ValueError("yolo_model is required when strategy='yolo'")
        return _run_yolo_segmentation(
            model=yolo_model,
            image_path=query_path,
            confidence_threshold=confidence_threshold,
            device=resolved_device,
        )

    # strategy == "gdino_sam" (the only remaining valid value)
    if gdino_processor is None or gdino_model is None or sam_predictor is None:
        raise ValueError(
            "gdino_processor, gdino_model, and sam_predictor are required when strategy='gdino_sam'"
        )
    return _run_gdino_sam_segmentation(
        gdino_processor=gdino_processor,
        gdino_model=gdino_model,
        sam_predictor=sam_predictor,
        image_path=query_path,
        confidence_threshold=confidence_threshold,
        text_prompt=text_prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        device=resolved_device,
    )
|
src/snowleopard_reid/utils.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility functions for snow leopard re-identification.
|
| 2 |
+
|
| 3 |
+
This module provides common utilities used across the project.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_device(device: str | None = None, verbose: bool = True) -> str:
|
| 14 |
+
"""Get the device to use for computation.
|
| 15 |
+
|
| 16 |
+
Auto-detects GPU if available, or uses CPU as fallback.
|
| 17 |
+
Optionally allows manual override.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
device: Manual device override ('cpu', 'cuda', or None for auto-detect)
|
| 21 |
+
verbose: Whether to log device information
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
Device string ('cuda' or 'cpu')
|
| 25 |
+
|
| 26 |
+
Examples:
|
| 27 |
+
>>> device = get_device() # Auto-detect
|
| 28 |
+
>>> device = get_device('cpu') # Force CPU
|
| 29 |
+
>>> device = get_device('cuda') # Force CUDA (will fail if not available)
|
| 30 |
+
"""
|
| 31 |
+
if device is None:
|
| 32 |
+
# Auto-detect
|
| 33 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 34 |
+
|
| 35 |
+
if verbose:
|
| 36 |
+
if device == "cuda":
|
| 37 |
+
gpu_name = torch.cuda.get_device_name(0)
|
| 38 |
+
gpu_memory = torch.cuda.get_device_properties(0).total_memory / (
|
| 39 |
+
1024**3
|
| 40 |
+
)
|
| 41 |
+
logger.info(f"Using GPU: {gpu_name} ({gpu_memory:.1f} GB)")
|
| 42 |
+
else:
|
| 43 |
+
logger.info("Using CPU (no GPU available)")
|
| 44 |
+
else:
|
| 45 |
+
# Manual override
|
| 46 |
+
device = device.lower()
|
| 47 |
+
if device not in ["cpu", "cuda"]:
|
| 48 |
+
raise ValueError(f"Invalid device: {device}. Must be 'cpu' or 'cuda'")
|
| 49 |
+
|
| 50 |
+
if device == "cuda" and not torch.cuda.is_available():
|
| 51 |
+
raise RuntimeError(
|
| 52 |
+
"CUDA device requested but CUDA is not available. "
|
| 53 |
+
"Install CUDA-enabled PyTorch or use device='cpu'"
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
if verbose:
|
| 57 |
+
logger.info(f"Using device: {device}")
|
| 58 |
+
|
| 59 |
+
return device
|
src/snowleopard_reid/visualization.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Visualization utilities for snow leopard re-identification.
|
| 2 |
+
|
| 3 |
+
This module provides functions for visualizing keypoints, matches, and other
|
| 4 |
+
pipeline outputs for debugging and presentation.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import cv2
|
| 10 |
+
import numpy as np
|
| 11 |
+
from PIL import Image
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def draw_keypoints_overlay(
    image_path: Path | str,
    keypoints: np.ndarray,
    max_keypoints: int = 500,
    color: str = "blue",
    ps: int = 10,
) -> Image.Image:
    """Draw keypoints overlaid on an image.

    Args:
        image_path: Path to image file
        keypoints: Keypoints array of shape [N, 2] with (x, y) coordinates
        max_keypoints: Maximum number of keypoints to display
        color: Color name ('blue', 'red', 'green', etc.)
        ps: Point size for keypoints (circles are drawn with radius ps // 2)

    Returns:
        PIL Image with keypoints drawn

    Raises:
        FileNotFoundError: If the image cannot be read from image_path.
    """
    # Load image. cv2.imread returns None instead of raising on a missing or
    # unreadable file; without this guard the failure surfaces as a cryptic
    # cv2.cvtColor error on the next line.
    img = cv2.imread(str(image_path))
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Color name -> RGB tuple (drawing happens on the RGB array, so these
    # tuples are in RGB order, not OpenCV's usual BGR).
    color_map = {
        "blue": (0, 0, 255),
        "red": (255, 0, 0),
        "green": (0, 255, 0),
        "yellow": (255, 255, 0),
        "cyan": (0, 255, 255),
        "magenta": (255, 0, 255),
    }
    # Unknown color names fall back to blue.
    color_rgb = color_map.get(color.lower(), (0, 0, 255))

    # Draw at most max_keypoints filled circles.
    n_keypoints = min(len(keypoints), max_keypoints)
    for i in range(n_keypoints):
        x, y = keypoints[i]
        cv2.circle(img_rgb, (int(x), int(y)), ps // 2, color_rgb, -1)

    return Image.fromarray(img_rgb)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def draw_matched_keypoints(
    query_image_path: Path | str,
    catalog_image_path: Path | str,
    query_keypoints: np.ndarray,
    catalog_keypoints: np.ndarray,
    match_scores: np.ndarray,
    max_matches: int = 100,
) -> Image.Image:
    """Draw matched keypoints side-by-side with connecting lines.

    Args:
        query_image_path: Path to query image
        catalog_image_path: Path to catalog image
        query_keypoints: Query keypoints [N, 2]
        catalog_keypoints: Catalog keypoints [N, 2]
        match_scores: Match confidence scores [N]
        max_matches: Maximum number of matches to display

    Returns:
        PIL Image with side-by-side images and match lines

    Raises:
        FileNotFoundError: If either image cannot be read.
    """
    # Load images; cv2.imread returns None on failure, so fail loudly here
    # rather than with an opaque cvtColor error below.
    query_img = cv2.imread(str(query_image_path))
    if query_img is None:
        raise FileNotFoundError(f"Could not read image: {query_image_path}")
    catalog_img = cv2.imread(str(catalog_image_path))
    if catalog_img is None:
        raise FileNotFoundError(f"Could not read image: {catalog_image_path}")

    # Convert to RGB
    query_rgb = cv2.cvtColor(query_img, cv2.COLOR_BGR2RGB)
    catalog_rgb = cv2.cvtColor(catalog_img, cv2.COLOR_BGR2RGB)

    # Resize images to same height for side-by-side display
    max_height = 800
    query_h, query_w = query_rgb.shape[:2]
    catalog_h, catalog_w = catalog_rgb.shape[:2]

    # Calculate scaling factors.
    # NOTE(review): if only one image exceeds max_height, the other is
    # upscaled to max_height too so both end up the same height — presumably
    # intentional for alignment; confirm this is the desired behavior.
    if query_h > max_height or catalog_h > max_height:
        query_scale = max_height / query_h
        catalog_scale = max_height / catalog_h
    else:
        query_scale = 1.0
        catalog_scale = 1.0

    # Resize images
    new_query_h = int(query_h * query_scale)
    new_query_w = int(query_w * query_scale)
    new_catalog_h = int(catalog_h * catalog_scale)
    new_catalog_w = int(catalog_w * catalog_scale)

    query_resized = cv2.resize(query_rgb, (new_query_w, new_query_h))
    catalog_resized = cv2.resize(catalog_rgb, (new_catalog_w, new_catalog_h))

    # Keypoints must be scaled by the same factor as their image.
    query_kpts_scaled = query_keypoints * query_scale
    catalog_kpts_scaled = catalog_keypoints * catalog_scale

    # Create side-by-side canvas (black background where heights differ).
    combined_h = max(new_query_h, new_catalog_h)
    combined_w = new_query_w + new_catalog_w
    canvas = np.zeros((combined_h, combined_w, 3), dtype=np.uint8)

    # Place images on canvas: query on the left, catalog on the right.
    canvas[:new_query_h, :new_query_w] = query_resized
    canvas[:new_catalog_h, new_query_w : new_query_w + new_catalog_w] = catalog_resized

    # Offset catalog keypoints to account for horizontal placement
    catalog_kpts_offset = catalog_kpts_scaled.copy()
    catalog_kpts_offset[:, 0] += new_query_w

    # Draw matches (limit to max_matches)
    n_matches = min(len(query_kpts_scaled), max_matches)

    # Sort by match scores so the highest-confidence matches are drawn.
    if len(match_scores) > 0:
        sorted_indices = np.argsort(match_scores)[::-1][:n_matches]
    else:
        sorted_indices = np.arange(n_matches)

    # Draw lines and keypoints
    for idx in sorted_indices:
        query_pt = tuple(query_kpts_scaled[idx].astype(int))
        catalog_pt = tuple(catalog_kpts_offset[idx].astype(int))

        # Color line based on match score (green = high, yellow = medium,
        # red = low); 0.5 is used as a neutral default when no scores exist.
        score = match_scores[idx] if len(match_scores) > 0 else 0.5
        if score > 0.8:
            color = (0, 255, 0)  # Green
        elif score > 0.5:
            color = (255, 255, 0)  # Yellow
        else:
            color = (255, 0, 0)  # Red

        # Draw line
        cv2.line(canvas, query_pt, catalog_pt, color, 1)

        # Draw keypoint endpoints: red on the query side, blue on catalog.
        cv2.circle(canvas, query_pt, 5, (255, 0, 0), -1)
        cv2.circle(canvas, catalog_pt, 5, (0, 0, 255), -1)

    return Image.fromarray(canvas)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def draw_side_by_side_comparison(
    query_image_path: Path | str,
    catalog_image_path: Path | str,
    max_height: int = 800,
) -> Image.Image:
    """Draw query and catalog images side-by-side without keypoints or annotations.

    This provides a clean visual comparison of the two images without the visual
    clutter of feature matching overlays. Useful for assessing overall visual
    similarity and spotting patterns like spots, scars, or markings.

    Args:
        query_image_path: Path to query image
        catalog_image_path: Path to catalog/reference image
        max_height: Maximum height for resizing (default: 800)

    Returns:
        PIL Image with side-by-side images (no annotations)

    Raises:
        FileNotFoundError: If either image cannot be read.
    """
    # Load images; cv2.imread silently returns None on failure, so raise an
    # explicit error instead of letting cvtColor fail cryptically.
    query_img = cv2.imread(str(query_image_path))
    if query_img is None:
        raise FileNotFoundError(f"Could not read image: {query_image_path}")
    catalog_img = cv2.imread(str(catalog_image_path))
    if catalog_img is None:
        raise FileNotFoundError(f"Could not read image: {catalog_image_path}")

    # Convert to RGB
    query_rgb = cv2.cvtColor(query_img, cv2.COLOR_BGR2RGB)
    catalog_rgb = cv2.cvtColor(catalog_img, cv2.COLOR_BGR2RGB)

    # Resize images to same height for side-by-side display
    query_h, query_w = query_rgb.shape[:2]
    catalog_h, catalog_w = catalog_rgb.shape[:2]

    # Calculate scaling factors (matches draw_matched_keypoints: if either
    # image exceeds max_height, both are rescaled to max_height tall).
    if query_h > max_height or catalog_h > max_height:
        query_scale = max_height / query_h
        catalog_scale = max_height / catalog_h
    else:
        query_scale = 1.0
        catalog_scale = 1.0

    # Resize images
    new_query_h = int(query_h * query_scale)
    new_query_w = int(query_w * query_scale)
    new_catalog_h = int(catalog_h * catalog_scale)
    new_catalog_w = int(catalog_w * catalog_scale)

    query_resized = cv2.resize(query_rgb, (new_query_w, new_query_h))
    catalog_resized = cv2.resize(catalog_rgb, (new_catalog_w, new_catalog_h))

    # Create side-by-side canvas (black background where heights differ).
    combined_h = max(new_query_h, new_catalog_h)
    combined_w = new_query_w + new_catalog_w
    canvas = np.zeros((combined_h, combined_w, 3), dtype=np.uint8)

    # Place images on canvas (no keypoints or lines): query left, catalog right.
    canvas[:new_query_h, :new_query_w] = query_resized
    canvas[:new_catalog_h, new_query_w : new_query_w + new_catalog_w] = catalog_resized

    return Image.fromarray(canvas)
|