#!/usr/bin/env python3
"""Pre-compute pipeline results for all example images.

This script runs the full snow leopard identification pipeline on all example
images with all available feature extractors, caching the results for instant
display in the Gradio app.

Usage:
    # Process all example images with all extractors
    python scripts/precompute_cache.py

    # Process specific images
    python scripts/precompute_cache.py --images IMG_001.jpg IMG_002.jpg

    # Process with specific extractors only
    python scripts/precompute_cache.py --extractors sift superpoint

    # Clear cache and regenerate all
    python scripts/precompute_cache.py --clear

    # Show cache summary
    python scripts/precompute_cache.py --summary
"""

import argparse
import json
import logging
import shutil
import sys
import tempfile
from pathlib import Path

import cv2
import numpy as np
import torch
from PIL import Image

# Add project root to path for imports
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT / "src"))

from snowleopard_reid.cache import (  # noqa: E402
    CACHE_DIR,
    clear_cache,
    extract_location_body_part_from_filepath,
    get_cache_dir,
    get_cache_summary,
)
from snowleopard_reid.pipeline.stages import (  # noqa: E402
    run_feature_extraction_stage,
    run_matching_stage,
    run_preprocess_stage,
    run_segmentation_stage,
    select_best_mask,
)
from snowleopard_reid.pipeline.stages.segmentation import (  # noqa: E402
    load_gdino_model,
    load_sam_predictor,
)
from snowleopard_reid.visualization import draw_keypoints_overlay  # noqa: E402

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Configuration
CATALOG_ROOT = PROJECT_ROOT / "data" / "catalog"
SAM_CHECKPOINT_DIR = PROJECT_ROOT / "data" / "models"
SAM_CHECKPOINT_NAME = "sam_hq_vit_l.pth"
EXAMPLES_DIR = PROJECT_ROOT / "data" / "examples"
GDINO_MODEL_ID = "IDEA-Research/grounding-dino-base"
TEXT_PROMPT = "a snow leopard."
SAM_MODEL_TYPE = "vit_l"

# Set very high to get ALL matches (will be limited by catalog size)
TOP_K_ALL = 1000

# Default top_k for display
TOP_K_DEFAULT = 5

# Keypoint budget passed to the extractor; also recorded in the cached metadata,
# so both places read this single constant and cannot drift apart.
MAX_KEYPOINTS = 2048

# Glob patterns used to discover example images when --images is not given
EXAMPLE_IMAGE_PATTERNS = ("*.jpg", "*.JPG", "*.png")

# All available extractors
ALL_EXTRACTORS = ["sift", "superpoint", "disk", "aliked"]


def make_relative_path(filepath: str, project_root: Path) -> str:
    """Convert absolute path to relative path from project root.

    Args:
        filepath: Absolute or relative file path
        project_root: Project root directory

    Returns:
        Relative path string from project root, or the input (as a string)
        when it is not located under ``project_root``.
    """
    try:
        return str(Path(filepath).relative_to(project_root))
    except ValueError:
        # Already relative or under a different root. Coerce to str so the
        # declared return type holds even if a Path object was passed in.
        return str(filepath)


def ensure_sam_model() -> Path:
    """Download SAM HQ model if not present.

    Returns:
        Path to the SAM HQ checkpoint file
    """
    # Imported lazily so that --summary / --clear style runs do not need it.
    from huggingface_hub import hf_hub_download

    sam_path = SAM_CHECKPOINT_DIR / SAM_CHECKPOINT_NAME
    if not sam_path.exists():
        logger.info("Downloading SAM HQ model (1.6GB)...")
        SAM_CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
        hf_hub_download(
            repo_id="lkeab/hq-sam",
            filename=SAM_CHECKPOINT_NAME,
            local_dir=SAM_CHECKPOINT_DIR,
        )
        logger.info("SAM HQ model downloaded successfully")
    return sam_path


def create_segmentation_viz(image_path: Path, mask: np.ndarray) -> Image.Image:
    """Create visualization of segmentation mask overlaid on image.

    Args:
        image_path: Path to original image
        mask: Binary segmentation mask

    Returns:
        PIL Image with segmentation overlay

    Raises:
        OSError: If the image cannot be read by OpenCV.
    """
    img = cv2.imread(str(image_path))
    if img is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with a clear message rather than inside cvtColor.
        raise OSError(f"Could not read image: {image_path}")
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Resize mask to match image dimensions if needed
    if mask.shape[:2] != img_rgb.shape[:2]:
        mask_resized = cv2.resize(
            mask.astype(np.uint8),
            (img_rgb.shape[1], img_rgb.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )
    else:
        mask_resized = mask

    # Create colored overlay
    overlay = img_rgb.copy()
    overlay[mask_resized > 0] = [255, 0, 0]  # Red for masked region

    # Blend
    alpha = 0.4
    blended = cv2.addWeighted(img_rgb, 1 - alpha, overlay, alpha, 0)
    return Image.fromarray(blended)


def _copy_pairwise_npz(matches: list, src_dir: Path, dst_dir: Path) -> int:
    """Copy per-match NPZ files from ``src_dir`` into ``dst_dir``.

    Source files are looked up as ``rank_<rank>_<catalog_id>.npz`` and written
    as ``<catalog_id>.npz``. Matches whose source file does not exist are
    skipped.

    Returns:
        Number of files copied.
    """
    copied = 0
    for match in matches:
        catalog_id = match["catalog_id"]
        rank = match["rank"]
        src_npz = src_dir / f"rank_{rank:02d}_{catalog_id}.npz"
        if src_npz.exists():
            shutil.copy2(src_npz, dst_dir / f"{catalog_id}.npz")
            copied += 1
    return copied


def process_and_cache(
    image_path: Path,
    extractor: str,
    gdino_processor,
    gdino_model,
    sam_predictor,
    device: str,
) -> bool:
    """Run full pipeline and cache ALL results for one image/extractor combination.

    This version caches ALL matches (not just top-k) with location/body_part
    metadata, and stores NPZ pairwise data for on-demand visualization
    generation.

    Args:
        image_path: Path to example image
        extractor: Feature extractor to use
        gdino_processor: Pre-loaded Grounding DINO processor
        gdino_model: Pre-loaded Grounding DINO model
        sam_predictor: Pre-loaded SAM HQ predictor
        device: Device to run on ('cuda' or 'cpu')

    Returns:
        True if successful, False otherwise (nothing detected, no matches, or
        any exception raised by a pipeline stage, which is logged).
    """
    logger.info(f"Processing {image_path.name} with {extractor.upper()}...")
    try:
        # Create temporary directory for intermediate files
        with tempfile.TemporaryDirectory(prefix="snowleopard_cache_") as temp_dir:
            work_dir = Path(temp_dir)

            # ============================================================
            # Stage 1: Segmentation (GDINO+SAM)
            # ============================================================
            logger.info(" Running GDINO+SAM segmentation...")
            seg_stage = run_segmentation_stage(
                image_path=image_path,
                strategy="gdino_sam",
                confidence_threshold=0.2,
                device=device,
                gdino_processor=gdino_processor,
                gdino_model=gdino_model,
                sam_predictor=sam_predictor,
                text_prompt=TEXT_PROMPT,
                box_threshold=0.30,
                text_threshold=0.20,
            )
            predictions = seg_stage["data"]["predictions"]
            if not predictions:
                logger.warning(f" No snow leopards detected in {image_path.name}")
                return False

            # ============================================================
            # Stage 2: Mask Selection
            # ============================================================
            logger.info(" Selecting best mask...")
            selected_idx, selected_pred = select_best_mask(
                predictions,
                strategy="confidence_area",
            )

            # Create segmentation visualization
            segmentation_image = create_segmentation_viz(
                image_path=image_path,
                mask=selected_pred["mask"],
            )

            # ============================================================
            # Stage 3: Preprocessing
            # ============================================================
            logger.info(" Preprocessing...")
            prep_stage = run_preprocess_stage(
                image_path=image_path,
                mask=selected_pred["mask"],
                padding=5,
            )
            cropped_image = prep_stage["data"]["cropped_image"]

            # Save cropped image for visualization functions
            cropped_path = work_dir / "cropped.jpg"
            cropped_image.save(cropped_path)

            # ============================================================
            # Stage 4: Feature Extraction
            # ============================================================
            logger.info(f" Extracting features ({extractor.upper()})...")
            feat_stage = run_feature_extraction_stage(
                image=cropped_image,
                extractor=extractor,
                max_keypoints=MAX_KEYPOINTS,
                device=device,
            )
            query_features = feat_stage["data"]["features"]

            # Create keypoints visualization
            query_kpts = query_features["keypoints"].cpu().numpy()
            keypoints_image = draw_keypoints_overlay(
                image_path=cropped_path,
                keypoints=query_kpts,
                max_keypoints=500,
                color="blue",
                ps=10,
            )

            # ============================================================
            # Stage 5: Matching - Get ALL matches
            # ============================================================
            logger.info(" Matching against catalog (ALL matches)...")
            temp_pairwise_dir = work_dir / "pairwise"
            temp_pairwise_dir.mkdir(exist_ok=True)
            match_stage = run_matching_stage(
                query_features=query_features,
                catalog_path=CATALOG_ROOT,
                top_k=TOP_K_ALL,  # Get ALL matches
                extractor=extractor,
                device=device,
                query_image_path=str(cropped_path),
                pairwise_output_dir=temp_pairwise_dir,
            )
            matches = match_stage["data"]["matches"]
            if not matches:
                logger.warning(f" No matches found for {image_path.name}")
                return False
            logger.info(f" Found {len(matches)} matches")

            # ============================================================
            # Enrich matches with location/body_part and convert to
            # relative paths
            # ============================================================
            logger.info(" Adding location/body_part metadata...")
            for match in matches:
                location, body_part = extract_location_body_part_from_filepath(
                    match["filepath"]
                )
                match["location"] = location
                match["body_part"] = body_part
                # Convert absolute path to relative (portable across environments)
                match["filepath"] = make_relative_path(
                    match["filepath"], PROJECT_ROOT
                )

            # ============================================================
            # Set up cache directory
            # ============================================================
            cache_dir = get_cache_dir(image_path, extractor)
            cache_dir.mkdir(parents=True, exist_ok=True)
            pairwise_dir = cache_dir / "pairwise"
            pairwise_dir.mkdir(exist_ok=True)

            # ============================================================
            # Copy NPZ files with catalog_id naming (not rank-based)
            # ============================================================
            logger.info(" Copying NPZ pairwise data...")
            npz_count = _copy_pairwise_npz(matches, temp_pairwise_dir, pairwise_dir)
            logger.info(f" Copied {npz_count} NPZ files")

            # ============================================================
            # Build Predictions Dict (v2.0 format with all_matches)
            # ============================================================
            predictions_dict = {
                "format_version": "2.0",
                "query_image": str(image_path),
                "extractor": extractor,
                "pipeline": {
                    "segmentation": {
                        "strategy": "gdino_sam",
                        "num_predictions": len(predictions),
                        "selected_idx": selected_idx,
                        "confidence": float(selected_pred["confidence"]),
                    },
                    "preprocessing": {
                        "padding": prep_stage["config"]["padding"],
                    },
                    "features": {
                        "num_keypoints": int(feat_stage["metrics"]["num_keypoints"]),
                        "extractor": extractor,
                        "max_keypoints": MAX_KEYPOINTS,
                    },
                    "matching": {
                        "num_catalog_images": match_stage["metrics"][
                            "num_catalog_images"
                        ],
                        "num_successful_matches": match_stage["metrics"][
                            "num_successful_matches"
                        ],
                    },
                },
                "all_matches": matches,  # ALL matches with location/body_part
                "top_k": TOP_K_DEFAULT,
            }

            # ============================================================
            # Save Cache (predictions.json + visualization images)
            # ============================================================
            logger.info(" Saving to cache...")

            # Save predictions JSON
            predictions_file = cache_dir / "predictions.json"
            with open(predictions_file, "w", encoding="utf-8") as f:
                json.dump(predictions_dict, f, indent=2)

            # Save visualization images
            segmentation_image.save(cache_dir / "segmentation.png")
            cropped_image.save(cache_dir / "cropped.png")
            keypoints_image.save(cache_dir / "keypoints.png")

            # Log cache size
            cache_size = sum(
                f.stat().st_size for f in cache_dir.rglob("*") if f.is_file()
            )
            logger.info(
                f" Cached: {cache_dir.name} ({cache_size / 1024 / 1024:.2f} MB)"
            )
            logger.info(f" {len(matches)} matches, {npz_count} NPZ files")
            return True
    except Exception as e:
        # Top-level boundary for one image/extractor job: log the traceback
        # and report failure so the remaining jobs still run.
        logger.exception(f" Failed: {e}")
        return False


def _discover_example_images() -> list:
    """Return the example images in ``EXAMPLES_DIR``, de-duplicated and sorted.

    The same file can be returned by more than one pattern when glob matching
    is case-insensitive (``*.jpg`` and ``*.JPG``); collecting into a set keeps
    each image from being processed twice.
    """
    found = {
        path
        for pattern in EXAMPLE_IMAGE_PATTERNS
        for path in EXAMPLES_DIR.glob(pattern)
    }
    return sorted(found)


def main():
    """Parse CLI arguments, load the models once and cache every image/extractor pair."""
    parser = argparse.ArgumentParser(
        description="Pre-compute pipeline results for example images"
    )
    parser.add_argument(
        "--images",
        nargs="+",
        help="Specific image filenames to process (default: all in examples/)",
    )
    parser.add_argument(
        "--extractors",
        nargs="+",
        choices=ALL_EXTRACTORS,
        default=ALL_EXTRACTORS,
        help="Feature extractors to use (default: all)",
    )
    parser.add_argument(
        "--clear",
        action="store_true",
        help="Clear all cached results before processing",
    )
    parser.add_argument(
        "--summary",
        action="store_true",
        help="Show cache summary and exit",
    )
    parser.add_argument(
        "--device",
        choices=["cpu", "cuda"],
        default=None,
        help="Device to run on (default: auto-detect)",
    )
    args = parser.parse_args()

    # Show summary and exit
    if args.summary:
        summary = get_cache_summary()
        print("\n=== Cache Summary ===")
        print(f"Total cached: {summary['total_cached']} items")
        print(f"Total size: {summary['total_size_mb']:.2f} MB")
        print("\nCached items:")
        for item in summary["cached_items"]:
            print(
                f" - {item['image_stem']} ({item['extractor']}): "
                f"{item['size_mb']:.2f} MB"
            )
        return

    # Clear cache if requested
    if args.clear:
        logger.info("Clearing cache...")
        clear_cache()
        logger.info("Cache cleared")

    # Determine device
    if args.device:
        device = args.device
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda" and not torch.cuda.is_available():
        # Only reachable via an explicit --device cuda; stop with a clear
        # message instead of crashing on the GPU name lookup below.
        logger.error("CUDA was requested but is not available on this machine")
        sys.exit(1)
    logger.info(f"Using device: {device}")
    if device == "cuda":
        logger.info(f"GPU: {torch.cuda.get_device_name(0)}")

    # Find example images
    if args.images:
        requested = [EXAMPLES_DIR / img for img in args.images]
        for path in requested:
            if not path.exists():
                logger.warning(f"Example image not found, skipping: {path}")
        # Filter existing files
        image_paths = [p for p in requested if p.exists()]
        if not image_paths:
            logger.error("No valid image paths found")
            sys.exit(1)
    else:
        image_paths = _discover_example_images()
        if not image_paths:
            logger.error(f"No example images found in {EXAMPLES_DIR}")
            sys.exit(1)

    logger.info(f"Found {len(image_paths)} example images")
    logger.info(f"Extractors: {', '.join(args.extractors)}")

    # Ensure SAM model is downloaded
    logger.info("Checking for SAM HQ model...")
    sam_path = ensure_sam_model()

    # Load GDINO model once
    logger.info(f"Loading Grounding DINO model: {GDINO_MODEL_ID}...")
    gdino_processor, gdino_model = load_gdino_model(
        model_id=GDINO_MODEL_ID,
        device=device,
    )
    logger.info("Grounding DINO model loaded")

    # Load SAM HQ model once
    logger.info(f"Loading SAM HQ model from {sam_path}...")
    sam_predictor = load_sam_predictor(
        checkpoint_path=sam_path,
        model_type=SAM_MODEL_TYPE,
        device=device,
    )
    logger.info("SAM HQ model loaded")

    # Process all combinations (image-major order, extractors innermost)
    jobs = [
        (image_path, extractor)
        for image_path in image_paths
        for extractor in args.extractors
    ]
    total = len(jobs)
    success = 0
    failed = 0
    for current, (image_path, extractor) in enumerate(jobs, start=1):
        logger.info(f"\n[{current}/{total}] Processing...")
        if process_and_cache(
            image_path=image_path,
            extractor=extractor,
            gdino_processor=gdino_processor,
            gdino_model=gdino_model,
            sam_predictor=sam_predictor,
            device=device,
        ):
            success += 1
        else:
            failed += 1

    # Final summary
    logger.info("\n" + "=" * 50)
    logger.info("PRECOMPUTATION COMPLETE")
    logger.info("=" * 50)
    logger.info(f"Success: {success}/{total}")
    logger.info(f"Failed: {failed}/{total}")

    # Show cache summary
    summary = get_cache_summary()
    logger.info(f"Total cache size: {summary['total_size_mb']:.2f} MB")


if __name__ == "__main__":
    main()