"""Cache utilities for precomputed pipeline results.

This module provides functions for loading and saving cached pipeline results,
enabling instant display of results for example images without running the
expensive pipeline (GDINO+SAM segmentation, feature extraction, matching) on CPU.

Cache Structure (v2.0 - supports filtering):
    cached_results/
    ├── {image_stem}_{extractor}/
    │   ├── predictions.json      # ALL matches with location/body_part
    │   ├── segmentation.png      # Segmentation visualization
    │   ├── cropped.png           # Cropped snow leopard image
    │   ├── keypoints.png         # Extracted keypoints visualization
    │   └── pairwise/
    │       ├── {catalog_id}.npz  # NPZ data for ALL matches
    │       └── ...               # (visualizations generated on-demand)

save_cache_results() can additionally write pre-rendered
rank_{NN}_{catalog_id}_match.png / rank_{NN}_{catalog_id}_clean.png files into
pairwise/; load_cached_match_visualizations() reads those when present.
"""
| |
|
| | import copy |
| | import json |
| | import logging |
| | from pathlib import Path |
| |
|
| | import numpy as np |
| | from PIL import Image |
| |
|
| | from snowleopard_reid.visualization import ( |
| | draw_matched_keypoints, |
| | draw_side_by_side_comparison, |
| | ) |
| |
|
logger = logging.getLogger(name=__name__)

# Repository root: three directory levels above this module file.
PROJECT_ROOT: Path = Path(__file__).parent.parent.parent
# Every cached pipeline result lives in one directory under the repository root.
CACHE_DIR: Path = PROJECT_ROOT.joinpath("cached_results")
| |
|
| |
|
| | def get_cache_key(image_path: Path | str, extractor: str) -> str: |
| | """Generate cache key from image stem and extractor. |
| | |
| | Args: |
| | image_path: Path to the query image |
| | extractor: Feature extractor name (e.g., 'sift', 'superpoint') |
| | |
| | Returns: |
| | Cache key string in format "{image_stem}_{extractor}" |
| | """ |
| | image_path = Path(image_path) |
| | return f"{image_path.stem}_{extractor}" |
| |
|
| |
|
def get_cache_dir(image_path: Path | str, extractor: str) -> Path:
    """Return the directory that holds cached results for one image/extractor pair.

    Args:
        image_path: Path to the query image
        extractor: Feature extractor name

    Returns:
        CACHE_DIR joined with the key produced by get_cache_key()
    """
    key = get_cache_key(image_path, extractor)
    return CACHE_DIR.joinpath(key)
| |
|
| |
|
def is_cached(image_path: Path | str, extractor: str) -> bool:
    """Report whether a complete cache entry exists for this image/extractor.

    An entry is complete when predictions.json and the three overview PNGs
    (segmentation, cropped, keypoints) are all present. The pairwise/
    subdirectory is not checked.

    Args:
        image_path: Path to the query image
        extractor: Feature extractor name

    Returns:
        True if every required cache file exists, else False
    """
    cache_dir = get_cache_dir(image_path, extractor)
    # predictions.json is checked first; all() short-circuits on the first miss.
    expected = ("predictions.json", "segmentation.png", "cropped.png", "keypoints.png")
    return all(cache_dir.joinpath(name).exists() for name in expected)
| |
|
| |
|
def load_cached_results(image_path: Path | str, extractor: str) -> dict:
    """Load all cached results for an image/extractor combination.

    Images are decoded eagerly and their file handles closed before returning,
    so the cache directory can be modified or removed (e.g. by clear_cache)
    while the returned images are still in use.

    Args:
        image_path: Path to the query image
        extractor: Feature extractor name

    Returns:
        Dictionary containing:
            - predictions: Full pipeline predictions dict (parsed predictions.json)
            - segmentation_image: PIL Image of segmentation overlay
            - cropped_image: PIL Image of cropped snow leopard
            - keypoints_image: PIL Image of extracted keypoints
            - pairwise_dir: Path to the "pairwise" subdirectory (NPZ match data
              and, if present, pre-rendered match visualizations)

    Raises:
        FileNotFoundError: If the cache directory, predictions.json, or any of
            the three cached PNG files is missing.
    """
    cache_dir = get_cache_dir(image_path, extractor)

    if not cache_dir.exists():
        raise FileNotFoundError(f"Cache directory not found: {cache_dir}")

    predictions_file = cache_dir / "predictions.json"
    if not predictions_file.exists():
        raise FileNotFoundError(f"Predictions file not found: {predictions_file}")

    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(predictions_file, encoding="utf-8") as f:
        predictions = json.load(f)

    def _load_image(filename: str) -> "Image.Image":
        # Image.open() is lazy and keeps the file open until pixel data is
        # read. Force the read inside a context manager so the handle closes.
        with Image.open(cache_dir / filename) as img:
            img.load()
        return img

    return {
        "predictions": predictions,
        "segmentation_image": _load_image("segmentation.png"),
        "cropped_image": _load_image("cropped.png"),
        "keypoints_image": _load_image("keypoints.png"),
        "pairwise_dir": cache_dir / "pairwise",
    }
| |
|
| |
|
def load_cached_match_visualizations(
    pairwise_dir: Path,
    matches: list[dict],
) -> tuple[dict, dict]:
    """Load pre-rendered match and clean comparison visualizations, if cached.

    For each match, looks for "rank_{rank:02d}_{catalog_id}_match.png" and
    "rank_{rank:02d}_{catalog_id}_clean.png" in pairwise_dir (the names written
    by save_cache_results). Missing files are skipped. Images are loaded
    eagerly so no file handles stay open after returning.

    Args:
        pairwise_dir: Path (or str) to the pairwise visualizations directory
        matches: List of match dictionaries with integer "rank" and "catalog_id"

    Returns:
        Tuple of (match_visualizations, clean_comparison_visualizations).
        Both are dicts mapping rank -> PIL Image. Ranks whose file is absent
        are omitted.
    """
    pairwise_dir = Path(pairwise_dir)
    match_visualizations = {}
    clean_comparison_visualizations = {}

    for match in matches:
        rank = match["rank"]
        prefix = f"rank_{rank:02d}_{match['catalog_id']}"

        for suffix, target in (
            ("match", match_visualizations),
            ("clean", clean_comparison_visualizations),
        ):
            path = pairwise_dir / f"{prefix}_{suffix}.png"
            if not path.exists():
                continue
            # Image.open() is lazy and holds the file open; read now and close.
            with Image.open(path) as img:
                img.load()
            target[rank] = img

    return match_visualizations, clean_comparison_visualizations
| |
|
| |
|
def save_cache_results(
    image_path: Path | str,
    extractor: str,
    predictions: dict,
    segmentation_image: Image.Image,
    cropped_image: Image.Image,
    keypoints_image: Image.Image,
    match_visualizations: dict[int, Image.Image],
    clean_comparison_visualizations: dict[int, Image.Image],
    matches: list[dict],
) -> Path:
    """Save pipeline results to cache.

    Writes predictions.json and the three overview PNGs into the directory
    returned by get_cache_dir(). For each rank listed in matches, it also
    writes the match and clean PNGs into pairwise/. Existing files are
    overwritten.

    Args:
        image_path: Path to the original query image
        extractor: Feature extractor name
        predictions: Full pipeline predictions dictionary (must be JSON-serializable)
        segmentation_image: PIL Image of segmentation overlay
        cropped_image: PIL Image of cropped snow leopard
        keypoints_image: PIL Image of extracted keypoints
        match_visualizations: Dict mapping rank -> match visualization PIL Image
        clean_comparison_visualizations: Dict mapping rank -> clean comparison PIL Image
        matches: List of match dictionaries with rank and catalog_id. Only ranks
            listed here are written; extra entries in the two dicts are ignored.

    Returns:
        Path to the cache directory
    """
    cache_dir = get_cache_dir(image_path, extractor)
    cache_dir.mkdir(parents=True, exist_ok=True)

    predictions_file = cache_dir / "predictions.json"
    # Explicit UTF-8 so the file is portable regardless of platform locale.
    with open(predictions_file, "w", encoding="utf-8") as f:
        json.dump(predictions, f, indent=2)
    logger.info("Saved predictions: %s", predictions_file)

    segmentation_image.save(cache_dir / "segmentation.png")
    cropped_image.save(cache_dir / "cropped.png")
    keypoints_image.save(cache_dir / "keypoints.png")
    logger.info("Saved visualization images to %s", cache_dir)

    pairwise_dir = cache_dir / "pairwise"
    pairwise_dir.mkdir(exist_ok=True)

    saved_match = saved_clean = 0
    for match in matches:
        rank = match["rank"]
        prefix = f"rank_{rank:02d}_{match['catalog_id']}"

        if rank in match_visualizations:
            match_visualizations[rank].save(pairwise_dir / f"{prefix}_match.png")
            saved_match += 1

        if rank in clean_comparison_visualizations:
            clean_comparison_visualizations[rank].save(pairwise_dir / f"{prefix}_clean.png")
            saved_clean += 1

    # Report what was actually written, not the size of the input dicts.
    logger.info(
        "Saved %d match and %d clean pairwise visualizations",
        saved_match,
        saved_clean,
    )

    return cache_dir
| |
|
| |
|
def clear_cache(
    image_path: Path | str | None = None,
    extractor: str | None = None,
) -> None:
    """Clear cached results.

    Which cache directories are removed depends on the arguments:

    - image_path and extractor: only the "{image_stem}_{extractor}" directory.
    - image_path only: that image's caches for every extractor.
    - extractor only: every image's cache built with that extractor.
    - neither: the entire cache directory.

    Empty strings are treated the same as None.

    Args:
        image_path: If provided, only clear caches for this image
        extractor: If provided, only clear caches built with this extractor

    Note:
        Matching an image without an extractor splits directory names on the
        last underscore (as get_cache_summary does). It therefore assumes
        extractor names contain no underscore.
    """
    import shutil

    if image_path and extractor:
        cache_dir = get_cache_dir(image_path, extractor)
        if cache_dir.exists():
            shutil.rmtree(cache_dir)
            logger.info("Cleared cache: %s", cache_dir)
        return

    if not CACHE_DIR.exists():
        return

    if not image_path and not extractor:
        shutil.rmtree(CACHE_DIR)
        logger.info("Cleared all caches: %s", CACHE_DIR)
        return

    # Only one selector was given. Remove just the matching directories;
    # never fall through to wiping the whole cache.
    stem = Path(image_path).stem if image_path else None
    suffix = f"_{extractor}" if extractor else None
    # Snapshot the listing before deleting entries from the directory.
    for cache_dir in list(CACHE_DIR.iterdir()):
        if not cache_dir.is_dir():
            continue
        if stem is not None and cache_dir.name.rpartition("_")[0] != stem:
            continue
        if suffix is not None and not cache_dir.name.endswith(suffix):
            continue
        shutil.rmtree(cache_dir)
        logger.info("Cleared cache: %s", cache_dir)
| |
|
| |
|
def get_cache_summary() -> dict:
    """Summarize what is currently stored in the cache directory.

    Returns:
        Dict with:
            - "total_cached": number of cache subdirectories
            - "total_size_mb": combined size of their files, in MiB
            - "cached_items": one dict per subdirectory with image_stem,
              extractor, size_mb and path

        Directory names are split on their last underscore; names without one
        report extractor "unknown".
    """
    if not CACHE_DIR.exists():
        return {"total_cached": 0, "total_size_mb": 0, "cached_items": []}

    mib = 1024 * 1024
    items: list[dict] = []
    grand_total = 0

    for entry in CACHE_DIR.iterdir():
        if not entry.is_dir():
            continue

        entry_bytes = sum(p.stat().st_size for p in entry.rglob("*") if p.is_file())
        grand_total += entry_bytes

        stem, sep, ext = entry.name.rpartition("_")
        if not sep:
            stem, ext = entry.name, "unknown"

        items.append(
            {
                "image_stem": stem,
                "extractor": ext,
                "size_mb": entry_bytes / mib,
                "path": str(entry),
            }
        )

    return {
        "total_cached": len(items),
        "total_size_mb": grand_total / mib,
        "cached_items": items,
    }
| |
|
| |
|
| | def filter_cached_matches( |
| | all_matches: list[dict], |
| | filter_locations: list[str] | None = None, |
| | filter_body_parts: list[str] | None = None, |
| | top_k: int = 5, |
| | ) -> list[dict]: |
| | """Filter cached matches by location/body_part and return top-k. |
| | |
| | Args: |
| | all_matches: List of all cached match dictionaries |
| | filter_locations: List of locations to filter by (e.g., ["skycrest_valley"]) |
| | filter_body_parts: List of body parts to filter by (e.g., ["head", "right_flank"]) |
| | top_k: Number of top matches to return after filtering |
| | |
| | Returns: |
| | List of filtered and re-ranked match dictionaries |
| | """ |
| | |
| | filtered = [copy.deepcopy(m) for m in all_matches] |
| |
|
| | if filter_locations: |
| | filtered = [m for m in filtered if m.get("location") in filter_locations] |
| |
|
| | if filter_body_parts: |
| | filtered = [m for m in filtered if m.get("body_part") in filter_body_parts] |
| |
|
| | |
| | filtered = sorted(filtered, key=lambda x: x.get("wasserstein", 0), reverse=True) |
| |
|
| | |
| | for i, match in enumerate(filtered[:top_k]): |
| | match["rank"] = i + 1 |
| |
|
| | return filtered[:top_k] |
| |
|
| |
|
def generate_visualizations_from_npz(
    pairwise_dir: Path,
    matches: list[dict],
    cropped_image_path: Path | str,
    max_matches: int = 100,
) -> tuple[dict, dict]:
    """Generate match visualizations on-demand from cached NPZ data.

    For each match, reads pairwise_dir/{catalog_id}.npz (keys
    "query_keypoints", "catalog_keypoints", "match_scores"). It then renders a
    keypoint-match image and a clean side-by-side comparison against the
    catalog image. Matches whose NPZ file is missing, or whose rendering
    fails, are logged as warnings and skipped.

    Args:
        pairwise_dir: Path to directory containing NPZ pairwise data files
        matches: List of filtered match dictionaries with rank, catalog_id and
            filepath (relative filepaths are resolved against PROJECT_ROOT)
        cropped_image_path: Path to the cropped query image
        max_matches: Maximum number of keypoint matches drawn per visualization

    Returns:
        Tuple of (match_visualizations, clean_comparison_visualizations)
        Both are dicts mapping rank -> PIL Image
    """
    pairwise_dir = Path(pairwise_dir)
    cropped_image_path = Path(cropped_image_path)
    match_visualizations = {}
    clean_comparison_visualizations = {}

    for match in matches:
        rank = match["rank"]
        catalog_id = match["catalog_id"]

        # Relative catalog filepaths are stored relative to the project root.
        catalog_image_path = Path(match["filepath"])
        if not catalog_image_path.is_absolute():
            catalog_image_path = PROJECT_ROOT / catalog_image_path

        npz_path = pairwise_dir / f"{catalog_id}.npz"
        if not npz_path.exists():
            logger.warning("NPZ file not found for %s: %s", catalog_id, npz_path)
            continue

        try:
            # np.load on an .npz returns an NpzFile that holds an open file
            # handle; the context manager closes it. Arrays are read into memory
            # on key access, so they remain valid after the block.
            # NOTE(review): allow_pickle=True allows code execution from crafted
            # files. It is only acceptable because the cache is generated locally.
            with np.load(npz_path, allow_pickle=True) as pairwise_data:
                query_keypoints = pairwise_data["query_keypoints"]
                catalog_keypoints = pairwise_data["catalog_keypoints"]
                match_scores = pairwise_data["match_scores"]

            match_visualizations[rank] = draw_matched_keypoints(
                query_image_path=cropped_image_path,
                catalog_image_path=catalog_image_path,
                query_keypoints=query_keypoints,
                catalog_keypoints=catalog_keypoints,
                match_scores=match_scores,
                max_matches=max_matches,
            )

            clean_comparison_visualizations[rank] = draw_side_by_side_comparison(
                query_image_path=cropped_image_path,
                catalog_image_path=catalog_image_path,
            )
        except Exception as e:
            # Best-effort: one bad entry must not prevent the others from rendering.
            logger.warning("Failed to generate visualization for %s: %s", catalog_id, e)

    return match_visualizations, clean_comparison_visualizations
| |
|
| |
|
| | def extract_location_body_part_from_filepath(filepath: str) -> tuple[str, str]: |
| | """Extract location and body_part from catalog image filepath. |
| | |
| | Expected filepath format: |
| | .../database/{location}/{individual}/images/{body_part}/{filename} |
| | |
| | Args: |
| | filepath: Path to catalog image |
| | |
| | Returns: |
| | Tuple of (location, body_part) |
| | """ |
| | parts = Path(filepath).parts |
| |
|
| | |
| | try: |
| | db_idx = parts.index("database") |
| | location = parts[db_idx + 1] if db_idx + 1 < len(parts) else "unknown" |
| |
|
| | |
| | img_idx = parts.index("images") |
| | body_part = parts[img_idx + 1] if img_idx + 1 < len(parts) else "unknown" |
| |
|
| | return location, body_part |
| | except (ValueError, IndexError): |
| | return "unknown", "unknown" |
| |
|