itsbava commited on
Commit
9a7a94e
·
verified ·
1 Parent(s): ff08aee

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +127 -68
handler.py CHANGED
@@ -1,86 +1,145 @@
1
  from typing import Dict, List, Any
2
  import json
3
  import base64
4
- import asyncio
5
- import uvicorn
6
- from threading import Thread
7
- import time
 
 
8
 
9
class EndpointHandler:
    def __init__(self, path=""):
        """Initialize the endpoint handler.

        Args:
            path: Model/artifact path supplied by the hosting platform
                (stored but not otherwise used; kept for interface
                compatibility with the HF endpoint contract).
        """
        self.path = path
        self.app = None            # FastAPI app object, set by _initialize_app()
        self.server_thread = None  # background uvicorn thread, started lazily
        self._initialize_app()

    def _initialize_app(self):
        """Import the FastAPI app from the local `app` module.

        Re-raises any import failure after logging, so the platform sees a
        hard startup error instead of a half-initialized handler.
        """
        try:
            from app import app
            self.app = app
            print("FastAPI app loaded successfully")
        except Exception as e:
            print(f"Error loading app: {e}")
            raise

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Handle inference requests from HuggingFace
        This method gets called for each request
        """
        try:
            # Start the FastAPI server if not already running
            if not self.server_thread or not self.server_thread.is_alive():
                self._start_server()

            # For now, return a simple response
            # The actual API calls will go through FastAPI endpoints
            return {
                "status": "success",
                "message": "FastAPI server is running",
                "endpoints": [
                    "/health",
                    "/extract_embeddings_batch",
                    "/create_faiss_index",
                    "/search_faiss"
                ],
                "server_url": "http://0.0.0.0:8000"
            }

        except Exception as e:
            return {
                "status": "error",
                "message": str(e)
            }

    def _start_server(self):
        """Start the FastAPI server in a background thread"""
        def run_server():
            try:
                uvicorn.run(
                    self.app,
                    host="0.0.0.0",
                    port=8000,
                    log_level="info"
                )
            except Exception as e:
                print(f"Server error: {e}")

        # Daemon thread so the worker process can exit without joining uvicorn.
        self.server_thread = Thread(target=run_server, daemon=True)
        self.server_thread.start()

        # Give the server time to start
        # NOTE(review): fixed 5s wait is a heuristic, not a readiness check —
        # uvicorn may still be binding when this returns.
        time.sleep(5)
        print("FastAPI server started in background thread")
76
-
77
# Create the handler instance that HuggingFace expects
def get_handler():
    """Factory hook: return a fresh EndpointHandler for the platform."""
    return EndpointHandler()
80
-
81
# For direct testing
if __name__ == "__main__":
    handler = EndpointHandler()
    # Test the handler with a dummy payload and print the response dict.
    result = handler({"test": "data"})
    print(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import Dict, List, Any
2
  import json
3
  import base64
4
+ import numpy as np
5
+ import cv2
6
+ import torch
7
+ import insightface
8
+ from PIL import Image
9
+ import io
10
 
11
class EndpointHandler:
    def __init__(self, path=""):
        """Initialize the handler and eagerly load the face model.

        Args:
            path: Artifact path supplied by the hosting platform; unused
                here but kept for interface compatibility.
        """
        self.face_app = None  # insightface FaceAnalysis, set in _init_model()
        self.use_gpu = False  # True when CUDA is available (via torch)
        self._init_model()

    def _init_model(self):
        """Initialize InsightFace model (detection + recognition modules only)."""
        self.use_gpu = torch.cuda.is_available()

        if self.use_gpu:
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
            ctx_id = 0   # first CUDA device
        else:
            providers = ['CPUExecutionProvider']
            ctx_id = -1  # insightface's CPU sentinel

        self.face_app = insightface.app.FaceAnalysis(
            providers=providers,
            allowed_modules=['detection', 'recognition']
        )
        self.face_app.prepare(ctx_id=ctx_id, det_size=(640, 640))
        print(f"Face model loaded: {'GPU' if self.use_gpu else 'CPU'}")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle the actual inference request.

        Supported request shapes:
          * {"inputs": "test"}          -> health-check status dict
          * {"images": [b64, ...], ...} -> batch embedding extraction
        Any other payload returns {"error": "Unknown request format"}.
        """
        try:
            # Handle health check
            if data.get("inputs") == "test":
                return {
                    "status": "healthy",
                    "gpu_available": self.use_gpu,
                    "model_loaded": self.face_app is not None
                }

            # Handle batch embedding extraction
            if "images" in data:
                return self._extract_embeddings_batch(data)

            return {"error": "Unknown request format"}

        except Exception as e:
            # Endpoint boundary: surface failures as a JSON error payload
            # instead of crashing the worker.
            return {"error": str(e)}

    def _extract_embeddings_batch(self, data):
        """Extract one L2-normalized face embedding per base64-encoded image.

        Args:
            data: Request dict with keys:
                images: list of base64-encoded image bytes (any format
                    cv2.imdecode accepts).
                enhance_quality: optional bool (default True) — run
                    _enhance_image before detection.
                aggressive_enhancement: optional bool (default False) —
                    select the stronger enhancement variant.

        Returns:
            Dict with parallel lists `embeddings` (list[float] or None per
            image) and `extraction_info` (per-image metadata or error), plus
            batch counters. A failing image contributes None + an error
            entry; it never aborts the rest of the batch.
        """
        images = data.get("images", [])
        enhance_quality = data.get("enhance_quality", True)
        aggressive = data.get("aggressive_enhancement", False)

        embeddings = []
        extraction_info = []

        for idx, img_b64 in enumerate(images):
            try:
                # Decode image
                img_data = base64.b64decode(img_b64)
                img_array = np.frombuffer(img_data, dtype=np.uint8)
                img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)

                if img is None:
                    embeddings.append(None)
                    extraction_info.append({"error": "Failed to decode", "index": idx})
                    continue

                # Enhance if requested
                if enhance_quality:
                    img = self._enhance_image(img, aggressive)

                # Extract faces
                faces = self.face_app.get(img)

                if len(faces) == 0:
                    embeddings.append(None)
                    extraction_info.append({
                        "face_count": 0,
                        "strategy_used": "gpu_batch" if self.use_gpu else "cpu_batch",
                        "enhancement_used": enhance_quality,
                        "index": idx
                    })
                    continue

                # Keep the largest detected face (by bounding-box area).
                face = max(faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))

                # L2-normalize so downstream cosine similarity reduces to a
                # dot product. Guard against a degenerate zero-norm embedding
                # (previously an unchecked division).
                norm = np.linalg.norm(face.embedding)
                embedding = face.embedding / norm if norm > 0 else face.embedding

                embeddings.append(embedding.tolist())

                # Heuristic confidence: relative face size, capped at 1.0.
                bbox_area = (face.bbox[2] - face.bbox[0]) * (face.bbox[3] - face.bbox[1])
                img_area = img.shape[0] * img.shape[1]
                confidence = min((bbox_area / img_area) * 2.0, 1.0)

                extraction_info.append({
                    "face_count": len(faces),
                    "confidence": float(confidence),
                    "strategy_used": "gpu_batch" if self.use_gpu else "cpu_batch",
                    "enhancement_used": enhance_quality,
                    "quality_score": float(confidence),
                    "index": idx
                })

            except Exception as e:
                # Per-image failure: record it and keep processing the batch.
                embeddings.append(None)
                extraction_info.append({"error": str(e), "index": idx})

        successful = len([e for e in embeddings if e is not None])

        return {
            "embeddings": embeddings,
            "extraction_info": extraction_info,
            "total_processed": len(images),
            "successful": successful,
            "processing_mode": "gpu" if self.use_gpu else "cpu"
        }

    def _enhance_image(self, img, aggressive=False):
        """Denoise/sharpen a BGR image before face detection.

        aggressive=True: stronger bilateral filter plus CLAHE applied to the
        LAB lightness channel. aggressive=False: lighter bilateral filter
        plus a 3x3 sharpening kernel. Enhancement is best-effort: on any
        failure the input image is returned unchanged.
        """
        try:
            if aggressive:
                img = cv2.bilateralFilter(img, 15, 90, 90)
                lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
                l, a, b = cv2.split(lab)
                clahe = cv2.createCLAHE(clipLimit=4.0, tileGridSize=(8, 8))
                l = clahe.apply(l)
                img = cv2.merge([l, a, b])
                img = cv2.cvtColor(img, cv2.COLOR_LAB2BGR)
            else:
                img = cv2.bilateralFilter(img, 9, 75, 75)
                kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
                img = cv2.filter2D(img, -1, kernel)

            return img
        except Exception:
            # Was a bare `except:`, which also swallowed KeyboardInterrupt
            # and SystemExit; narrowed to Exception.
            return img