feat: patch SAM to load on CPU
Browse files
src/snowleopard_reid/pipeline/stages/segmentation.py
CHANGED
|
@@ -70,7 +70,20 @@ def load_sam_predictor(
|
|
| 70 |
device_str = get_device(device=device, verbose=True)
|
| 71 |
|
| 72 |
logger.info(f"Loading SAM HQ model: {model_type}")
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
sam.to(device=device_str)
|
| 75 |
|
| 76 |
predictor = SamPredictor(sam)
|
|
|
|
| 70 |
device_str = get_device(device=device, verbose=True)
|
| 71 |
|
| 72 |
logger.info(f"Loading SAM HQ model: {model_type}")
|
| 73 |
+
|
| 74 |
+
# Patch torch.load to handle CPU-only environments
|
| 75 |
+
# SAM HQ's registry doesn't accept map_location, so we patch temporarily
|
| 76 |
+
import torch
|
| 77 |
+
|
| 78 |
+
original_load = torch.load
|
| 79 |
+
torch.load = lambda f, *args, **kwargs: original_load(
|
| 80 |
+
f, *args, map_location=device_str, **kwargs
|
| 81 |
+
)
|
| 82 |
+
try:
|
| 83 |
+
sam = sam_model_registry[model_type](checkpoint=str(checkpoint_path))
|
| 84 |
+
finally:
|
| 85 |
+
torch.load = original_load
|
| 86 |
+
|
| 87 |
sam.to(device=device_str)
|
| 88 |
|
| 89 |
predictor = SamPredictor(sam)
|