achouffe commited on
Commit
1a0a002
·
verified ·
1 Parent(s): 785cdc2

feat: patch SAM to load on CPU

Browse files
src/snowleopard_reid/pipeline/stages/segmentation.py CHANGED
@@ -70,7 +70,20 @@ def load_sam_predictor(
70
  device_str = get_device(device=device, verbose=True)
71
 
72
  logger.info(f"Loading SAM HQ model: {model_type}")
73
- sam = sam_model_registry[model_type](checkpoint=str(checkpoint_path))
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  sam.to(device=device_str)
75
 
76
  predictor = SamPredictor(sam)
 
70
  device_str = get_device(device=device, verbose=True)
71
 
72
  logger.info(f"Loading SAM HQ model: {model_type}")
73
+
74
+ # Patch torch.load to handle CPU-only environments
75
+ # SAM HQ's registry doesn't accept map_location, so we patch temporarily
76
+ import torch
77
+
78
+ original_load = torch.load
79
+ torch.load = lambda f, *args, **kwargs: original_load(
80
+ f, *args, map_location=device_str, **kwargs
81
+ )
82
+ try:
83
+ sam = sam_model_registry[model_type](checkpoint=str(checkpoint_path))
84
+ finally:
85
+ torch.load = original_load
86
+
87
  sam.to(device=device_str)
88
 
89
  predictor = SamPredictor(sam)