Yurikks commited on
Commit
4601f86
·
verified ·
1 Parent(s): e23420c

Upload 2 files

Browse files
Files changed (2) hide show
  1. main.py +363 -358
  2. tts_service.py +109 -73
main.py CHANGED
@@ -1,358 +1,363 @@
1
- """
2
- TTS Backend for YorubaApp
3
- Uses facebook/mms-tts-yor model for Yoruba text-to-speech
4
-
5
- Security Features:
6
- - Firebase ID Token validation
7
- - Rate limiting per user (100 requests/day)
8
- - Request signature validation
9
- - API key fallback for development
10
- """
11
-
12
- import os
13
- import json
14
- import hmac
15
- import hashlib
16
- from datetime import datetime, timezone
17
- from typing import Optional
18
- from fastapi import FastAPI, HTTPException, Header, Depends
19
- from fastapi.middleware.cors import CORSMiddleware
20
- from pydantic import BaseModel
21
- import base64
22
- import logging
23
-
24
- import firebase_admin
25
- from firebase_admin import auth as firebase_auth, credentials
26
-
27
- from tts_service import TTSService
28
- from cache import TTSCache
29
-
30
- # Configure logging
31
- logging.basicConfig(level=logging.INFO)
32
- logger = logging.getLogger(__name__)
33
-
34
- # =============================================================================
35
- # CONFIGURATION
36
- # =============================================================================
37
-
38
- # API Key for development (fallback)
39
- API_KEY = os.environ.get("TTS_API_KEY", "")
40
-
41
- # Firebase configuration (set in HF Spaces secrets)
42
- FIREBASE_PROJECT_ID = os.environ.get("FIREBASE_PROJECT_ID", "demo-yorubaapp")
43
- FIREBASE_SERVICE_ACCOUNT = os.environ.get("FIREBASE_SERVICE_ACCOUNT_JSON", "")
44
-
45
- # Rate limiting
46
- MAX_REQUESTS_PER_DAY = 100
47
-
48
- # Request signing secret (for additional verification)
49
- REQUEST_SIGNING_SECRET = os.environ.get("REQUEST_SIGNING_SECRET", "")
50
-
51
- # =============================================================================
52
- # FIREBASE INITIALIZATION
53
- # =============================================================================
54
-
55
- firebase_initialized = False
56
-
57
- def initialize_firebase():
58
- global firebase_initialized
59
- if firebase_initialized:
60
- return
61
-
62
- try:
63
- if FIREBASE_SERVICE_ACCOUNT:
64
- # Parse JSON from environment variable
65
- cred_dict = json.loads(FIREBASE_SERVICE_ACCOUNT)
66
- cred = credentials.Certificate(cred_dict)
67
- firebase_admin.initialize_app(cred)
68
- logger.info("Firebase Admin SDK initialized with service account")
69
- else:
70
- # Initialize with just project ID (limited functionality)
71
- firebase_admin.initialize_app(options={
72
- 'projectId': FIREBASE_PROJECT_ID
73
- })
74
- logger.warning("Firebase Admin SDK initialized without service account (limited functionality)")
75
-
76
- firebase_initialized = True
77
- except Exception as e:
78
- logger.error(f"Failed to initialize Firebase Admin SDK: {e}")
79
- # Continue without Firebase - fall back to API key only
80
-
81
- # Initialize on startup
82
- initialize_firebase()
83
-
84
- # =============================================================================
85
- # RATE LIMITING (In-Memory - resets on restart)
86
- # =============================================================================
87
-
88
- # In production, this should use Redis or Firestore
89
- rate_limit_cache: dict[str, dict] = {}
90
-
91
- def check_rate_limit(user_id: str) -> tuple[bool, int]:
92
- """
93
- Check if user has exceeded rate limit.
94
- Returns (allowed, remaining_requests)
95
- """
96
- today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
97
- cache_key = f"{user_id}_{today}"
98
-
99
- if cache_key not in rate_limit_cache:
100
- rate_limit_cache[cache_key] = {"count": 0, "date": today}
101
-
102
- entry = rate_limit_cache[cache_key]
103
-
104
- # Reset if new day
105
- if entry["date"] != today:
106
- entry = {"count": 0, "date": today}
107
- rate_limit_cache[cache_key] = entry
108
-
109
- remaining = MAX_REQUESTS_PER_DAY - entry["count"]
110
-
111
- if entry["count"] >= MAX_REQUESTS_PER_DAY:
112
- return False, 0
113
-
114
- # Increment count
115
- entry["count"] += 1
116
- return True, remaining - 1
117
-
118
- def cleanup_old_rate_limits():
119
- """Remove entries from previous days"""
120
- today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
121
- keys_to_remove = [k for k, v in rate_limit_cache.items() if v.get("date") != today]
122
- for key in keys_to_remove:
123
- del rate_limit_cache[key]
124
-
125
- # =============================================================================
126
- # SECURITY HELPERS
127
- # =============================================================================
128
-
129
- async def verify_firebase_token(authorization: Optional[str]) -> Optional[dict]:
130
- """
131
- Verify Firebase ID token and return user info.
132
- Returns None if verification fails.
133
- """
134
- if not authorization or not authorization.startswith("Bearer "):
135
- return None
136
-
137
- token = authorization[7:] # Remove "Bearer " prefix
138
-
139
- try:
140
- decoded_token = firebase_auth.verify_id_token(token)
141
- return {
142
- "uid": decoded_token["uid"],
143
- "email": decoded_token.get("email"),
144
- "email_verified": decoded_token.get("email_verified", False)
145
- }
146
- except Exception as e:
147
- logger.warning(f"Firebase token verification failed: {e}")
148
- return None
149
-
150
- def verify_request_signature(
151
- user_id: str,
152
- text: str,
153
- timestamp: str,
154
- signature: str
155
- ) -> bool:
156
- """
157
- Verify HMAC signature of request.
158
- Signature = HMAC-SHA256(userId + timestamp + text)
159
- """
160
- if not REQUEST_SIGNING_SECRET:
161
- return True # Skip if not configured
162
-
163
- # Check timestamp (within 5 minutes)
164
- try:
165
- request_time = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
166
- now = datetime.now(timezone.utc)
167
- diff = abs((now - request_time).total_seconds())
168
- if diff > 300: # 5 minutes
169
- logger.warning(f"Request timestamp too old: {diff} seconds")
170
- return False
171
- except Exception as e:
172
- logger.warning(f"Invalid timestamp format: {e}")
173
- return False
174
-
175
- # Verify signature
176
- message = f"{user_id}{timestamp}{text}"
177
- expected_signature = hmac.new(
178
- REQUEST_SIGNING_SECRET.encode(),
179
- message.encode(),
180
- hashlib.sha256
181
- ).hexdigest()
182
-
183
- return hmac.compare_digest(signature, expected_signature)
184
-
185
- # =============================================================================
186
- # FASTAPI APP
187
- # =============================================================================
188
-
189
- app = FastAPI(
190
- title="YorubaApp TTS API",
191
- description="Text-to-Speech API for Yoruba language using MMS-TTS-YOR",
192
- version="2.0.0"
193
- )
194
-
195
- # CORS - allow requests from Expo dev server and production
196
- app.add_middleware(
197
- CORSMiddleware,
198
- allow_origins=["*"], # Configure for production
199
- allow_credentials=True,
200
- allow_methods=["*"],
201
- allow_headers=["*"],
202
- )
203
-
204
- # Initialize services
205
- tts = TTSService()
206
- cache = TTSCache()
207
-
208
- # =============================================================================
209
- # MODELS
210
- # =============================================================================
211
-
212
- class TTSRequest(BaseModel):
213
- text: str
214
- timestamp: Optional[str] = None # ISO format for signature verification
215
- signature: Optional[str] = None # HMAC signature
216
-
217
- class TTSResponse(BaseModel):
218
- audio: str # base64 encoded WAV
219
- cached: bool
220
- remaining_requests: Optional[int] = None
221
-
222
- # =============================================================================
223
- # ENDPOINTS
224
- # =============================================================================
225
-
226
- @app.get("/")
227
- async def root():
228
- return {"status": "ok", "service": "YorubaApp TTS API", "version": "2.0.0"}
229
-
230
- @app.get("/health")
231
- async def health():
232
- return {
233
- "status": "healthy",
234
- "model": "facebook/mms-tts-yor",
235
- "firebase_initialized": firebase_initialized
236
- }
237
-
238
- @app.post("/tts", response_model=TTSResponse)
239
- async def text_to_speech(
240
- request: TTSRequest,
241
- authorization: Optional[str] = Header(None),
242
- x_api_key: Optional[str] = Header(None, alias="X-API-Key")
243
- ):
244
- """
245
- Generate speech from text.
246
-
247
- Authentication (in order of priority):
248
- 1. Firebase ID Token (Authorization: Bearer <token>)
249
- 2. API Key (X-API-Key header) - for development only
250
-
251
- Rate limiting: 100 requests per user per day
252
- """
253
- user_info = None
254
- user_id = None
255
-
256
- # Try Firebase token first
257
- if authorization:
258
- user_info = await verify_firebase_token(authorization)
259
- if user_info:
260
- user_id = user_info["uid"]
261
- logger.info(f"Authenticated via Firebase: {user_id[:8]}...")
262
-
263
- # Fall back to API key
264
- if not user_info:
265
- if API_KEY and x_api_key == API_KEY:
266
- user_id = "api_key_user"
267
- logger.info("Authenticated via API key")
268
- else:
269
- raise HTTPException(status_code=401, detail="Invalid or missing authentication")
270
-
271
- # Validate request
272
- text = request.text.strip()
273
- if not text:
274
- raise HTTPException(status_code=400, detail="Text is required")
275
-
276
- if len(text) > 500:
277
- raise HTTPException(status_code=400, detail="Text too long (max 500 characters)")
278
-
279
- # Verify request signature (optional extra security)
280
- if request.timestamp and request.signature and user_id:
281
- if not verify_request_signature(user_id, text, request.timestamp, request.signature):
282
- raise HTTPException(status_code=401, detail="Invalid request signature")
283
-
284
- # Check rate limit
285
- allowed, remaining = check_rate_limit(user_id)
286
- if not allowed:
287
- raise HTTPException(
288
- status_code=429,
289
- detail="Daily rate limit exceeded. Please try again tomorrow."
290
- )
291
-
292
- logger.info(f"TTS request from {user_id[:8]}... for text: {text[:50]}...")
293
-
294
- # Check cache first
295
- cached_audio = await cache.get(text)
296
- if cached_audio:
297
- logger.info("Returning cached audio")
298
- return TTSResponse(audio=cached_audio, cached=True, remaining_requests=remaining)
299
-
300
- try:
301
- # Generate audio
302
- audio_bytes = await tts.synthesize(text)
303
- audio_b64 = base64.b64encode(audio_bytes).decode('utf-8')
304
-
305
- # Cache result
306
- await cache.set(text, audio_b64)
307
-
308
- logger.info(f"Generated audio: {len(audio_bytes)} bytes")
309
- return TTSResponse(audio=audio_b64, cached=False, remaining_requests=remaining)
310
-
311
- except Exception as e:
312
- logger.error(f"TTS synthesis failed: {e}")
313
- raise HTTPException(status_code=500, detail=f"TTS synthesis failed: {str(e)}")
314
-
315
- @app.get("/rate-limit/{user_id}")
316
- async def get_rate_limit_status(
317
- user_id: str,
318
- authorization: Optional[str] = Header(None)
319
- ):
320
- """
321
- Get current rate limit status for a user.
322
- Only accessible with valid Firebase token for the same user.
323
- """
324
- user_info = await verify_firebase_token(authorization)
325
-
326
- if not user_info or user_info["uid"] != user_id:
327
- raise HTTPException(status_code=401, detail="Unauthorized")
328
-
329
- today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
330
- cache_key = f"{user_id}_{today}"
331
-
332
- if cache_key in rate_limit_cache:
333
- count = rate_limit_cache[cache_key]["count"]
334
- else:
335
- count = 0
336
-
337
- return {
338
- "user_id": user_id,
339
- "date": today,
340
- "used": count,
341
- "limit": MAX_REQUESTS_PER_DAY,
342
- "remaining": max(0, MAX_REQUESTS_PER_DAY - count)
343
- }
344
-
345
- # =============================================================================
346
- # STARTUP
347
- # =============================================================================
348
-
349
- @app.on_event("startup")
350
- async def startup_event():
351
- """Cleanup old rate limit entries on startup"""
352
- cleanup_old_rate_limits()
353
- logger.info("TTS API started")
354
-
355
- if __name__ == "__main__":
356
- import uvicorn
357
- # Port 7860 is the default for Hugging Face Spaces
358
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
1
+ """
2
+ TTS Backend for YorubaApp
3
+ Uses facebook/mms-tts-yor model for Yoruba text-to-speech
4
+
5
+ Security Features:
6
+ - Firebase ID Token validation
7
+ - Rate limiting per user (100 requests/day)
8
+ - Request signature validation
9
+ - API key fallback for development
10
+ """
11
+
12
+ import os
13
+ import json
14
+ import hmac
15
+ import hashlib
16
+ from datetime import datetime, timezone
17
+ from typing import Optional
18
+ from fastapi import FastAPI, HTTPException, Header, Depends
19
+ from fastapi.middleware.cors import CORSMiddleware
20
+ from pydantic import BaseModel
21
+ import base64
22
+ import logging
23
+
24
+ import firebase_admin
25
+ from firebase_admin import auth as firebase_auth, credentials
26
+
27
+ from tts_service import TTSService
28
+ from cache import TTSCache
29
+
30
+ # Configure logging
31
+ logging.basicConfig(level=logging.INFO)
32
+ logger = logging.getLogger(__name__)
33
+
34
+ # =============================================================================
35
+ # CONFIGURATION
36
+ # =============================================================================
37
+
38
+ # API Key for development (fallback)
39
+ API_KEY = os.environ.get("TTS_API_KEY", "")
40
+
41
+ # Firebase configuration (set in HF Spaces secrets)
42
+ FIREBASE_PROJECT_ID = os.environ.get("FIREBASE_PROJECT_ID", "demo-yorubaapp")
43
+ FIREBASE_SERVICE_ACCOUNT = os.environ.get("FIREBASE_SERVICE_ACCOUNT_JSON", "")
44
+
45
+ # Rate limiting
46
+ MAX_REQUESTS_PER_DAY = 100
47
+
48
+ # Request signing secret (for additional verification)
49
+ REQUEST_SIGNING_SECRET = os.environ.get("REQUEST_SIGNING_SECRET", "")
50
+
51
+ # =============================================================================
52
+ # FIREBASE INITIALIZATION
53
+ # =============================================================================
54
+
55
+ firebase_initialized = False
56
+
57
+ def initialize_firebase():
58
+ global firebase_initialized
59
+ if firebase_initialized:
60
+ return
61
+
62
+ try:
63
+ if FIREBASE_SERVICE_ACCOUNT:
64
+ # Parse JSON from environment variable
65
+ cred_dict = json.loads(FIREBASE_SERVICE_ACCOUNT)
66
+ cred = credentials.Certificate(cred_dict)
67
+ firebase_admin.initialize_app(cred)
68
+ logger.info("Firebase Admin SDK initialized with service account")
69
+ else:
70
+ # Initialize with just project ID (limited functionality)
71
+ firebase_admin.initialize_app(options={
72
+ 'projectId': FIREBASE_PROJECT_ID
73
+ })
74
+ logger.warning("Firebase Admin SDK initialized without service account (limited functionality)")
75
+
76
+ firebase_initialized = True
77
+ except Exception as e:
78
+ logger.error(f"Failed to initialize Firebase Admin SDK: {e}")
79
+ # Continue without Firebase - fall back to API key only
80
+
81
+ # Initialize on startup
82
+ initialize_firebase()
83
+
84
+ # =============================================================================
85
+ # RATE LIMITING (In-Memory - resets on restart)
86
+ # =============================================================================
87
+
88
+ # In production, this should use Redis or Firestore
89
+ rate_limit_cache: dict[str, dict] = {}
90
+
91
+ def check_rate_limit(user_id: str) -> tuple[bool, int]:
92
+ """
93
+ Check if user has exceeded rate limit.
94
+ Returns (allowed, remaining_requests)
95
+ """
96
+ today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
97
+ cache_key = f"{user_id}_{today}"
98
+
99
+ if cache_key not in rate_limit_cache:
100
+ rate_limit_cache[cache_key] = {"count": 0, "date": today}
101
+
102
+ entry = rate_limit_cache[cache_key]
103
+
104
+ # Reset if new day
105
+ if entry["date"] != today:
106
+ entry = {"count": 0, "date": today}
107
+ rate_limit_cache[cache_key] = entry
108
+
109
+ remaining = MAX_REQUESTS_PER_DAY - entry["count"]
110
+
111
+ if entry["count"] >= MAX_REQUESTS_PER_DAY:
112
+ return False, 0
113
+
114
+ # Increment count
115
+ entry["count"] += 1
116
+ return True, remaining - 1
117
+
118
+ def cleanup_old_rate_limits():
119
+ """Remove entries from previous days"""
120
+ today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
121
+ keys_to_remove = [k for k, v in rate_limit_cache.items() if v.get("date") != today]
122
+ for key in keys_to_remove:
123
+ del rate_limit_cache[key]
124
+
125
+ # =============================================================================
126
+ # SECURITY HELPERS
127
+ # =============================================================================
128
+
129
+ async def verify_firebase_token(authorization: Optional[str]) -> Optional[dict]:
130
+ """
131
+ Verify Firebase ID token and return user info.
132
+ Returns None if verification fails.
133
+ """
134
+ if not authorization or not authorization.startswith("Bearer "):
135
+ return None
136
+
137
+ token = authorization[7:] # Remove "Bearer " prefix
138
+
139
+ try:
140
+ decoded_token = firebase_auth.verify_id_token(token)
141
+ return {
142
+ "uid": decoded_token["uid"],
143
+ "email": decoded_token.get("email"),
144
+ "email_verified": decoded_token.get("email_verified", False)
145
+ }
146
+ except Exception as e:
147
+ logger.warning(f"Firebase token verification failed: {e}")
148
+ return None
149
+
150
+ def verify_request_signature(
151
+ user_id: str,
152
+ text: str,
153
+ timestamp: str,
154
+ signature: str
155
+ ) -> bool:
156
+ """
157
+ Verify HMAC signature of request.
158
+ Signature = HMAC-SHA256(userId + timestamp + text)
159
+ """
160
+ if not REQUEST_SIGNING_SECRET:
161
+ return True # Skip if not configured
162
+
163
+ # Check timestamp (within 5 minutes)
164
+ try:
165
+ request_time = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
166
+ now = datetime.now(timezone.utc)
167
+ diff = abs((now - request_time).total_seconds())
168
+ if diff > 300: # 5 minutes
169
+ logger.warning(f"Request timestamp too old: {diff} seconds")
170
+ return False
171
+ except Exception as e:
172
+ logger.warning(f"Invalid timestamp format: {e}")
173
+ return False
174
+
175
+ # Verify signature
176
+ message = f"{user_id}{timestamp}{text}"
177
+ expected_signature = hmac.new(
178
+ REQUEST_SIGNING_SECRET.encode(),
179
+ message.encode(),
180
+ hashlib.sha256
181
+ ).hexdigest()
182
+
183
+ return hmac.compare_digest(signature, expected_signature)
184
+
185
+ # =============================================================================
186
+ # FASTAPI APP
187
+ # =============================================================================
188
+
189
+ app = FastAPI(
190
+ title="YorubaApp TTS API",
191
+ description="Text-to-Speech API for Yoruba language using MMS-TTS-YOR",
192
+ version="2.0.0"
193
+ )
194
+
195
+ # CORS - allow requests from Expo dev server and production
196
+ app.add_middleware(
197
+ CORSMiddleware,
198
+ allow_origins=["*"], # Configure for production
199
+ allow_credentials=True,
200
+ allow_methods=["*"],
201
+ allow_headers=["*"],
202
+ )
203
+
204
+ # Initialize services
205
+ tts = TTSService()
206
+ cache = TTSCache()
207
+
208
+ # =============================================================================
209
+ # MODELS
210
+ # =============================================================================
211
+
212
+ class TTSRequest(BaseModel):
213
+ text: str
214
+ speed: Optional[float] = 1.0 # Speed: 0.5-1.5 (1.0 = normal, 0.7 = devagar)
215
+ timestamp: Optional[str] = None # ISO format for signature verification
216
+ signature: Optional[str] = None # HMAC signature
217
+
218
+ class TTSResponse(BaseModel):
219
+ audio: str # base64 encoded WAV
220
+ cached: bool
221
+ remaining_requests: Optional[int] = None
222
+
223
+ # =============================================================================
224
+ # ENDPOINTS
225
+ # =============================================================================
226
+
227
+ @app.get("/")
228
+ async def root():
229
+ return {"status": "ok", "service": "YorubaApp TTS API", "version": "2.0.0"}
230
+
231
+ @app.get("/health")
232
+ async def health():
233
+ return {
234
+ "status": "healthy",
235
+ "model": "facebook/mms-tts-yor",
236
+ "firebase_initialized": firebase_initialized
237
+ }
238
+
239
+ @app.post("/tts", response_model=TTSResponse)
240
+ async def text_to_speech(
241
+ request: TTSRequest,
242
+ authorization: Optional[str] = Header(None),
243
+ x_api_key: Optional[str] = Header(None, alias="X-API-Key")
244
+ ):
245
+ """
246
+ Generate speech from text.
247
+
248
+ Authentication (in order of priority):
249
+ 1. Firebase ID Token (Authorization: Bearer <token>)
250
+ 2. API Key (X-API-Key header) - for development only
251
+
252
+ Rate limiting: 100 requests per user per day
253
+ """
254
+ user_info = None
255
+ user_id = None
256
+
257
+ # Try Firebase token first
258
+ if authorization:
259
+ user_info = await verify_firebase_token(authorization)
260
+ if user_info:
261
+ user_id = user_info["uid"]
262
+ logger.info(f"Authenticated via Firebase: {user_id[:8]}...")
263
+
264
+ # Fall back to API key
265
+ if not user_info:
266
+ if API_KEY and x_api_key == API_KEY:
267
+ user_id = "api_key_user"
268
+ logger.info("Authenticated via API key")
269
+ else:
270
+ raise HTTPException(status_code=401, detail="Invalid or missing authentication")
271
+
272
+ # Validate request
273
+ text = request.text.strip()
274
+ if not text:
275
+ raise HTTPException(status_code=400, detail="Text is required")
276
+
277
+ if len(text) > 500:
278
+ raise HTTPException(status_code=400, detail="Text too long (max 500 characters)")
279
+
280
+ # Verify request signature (optional extra security)
281
+ if request.timestamp and request.signature and user_id:
282
+ if not verify_request_signature(user_id, text, request.timestamp, request.signature):
283
+ raise HTTPException(status_code=401, detail="Invalid request signature")
284
+
285
+ # Check rate limit
286
+ allowed, remaining = check_rate_limit(user_id)
287
+ if not allowed:
288
+ raise HTTPException(
289
+ status_code=429,
290
+ detail="Daily rate limit exceeded. Please try again tomorrow."
291
+ )
292
+
293
+ logger.info(f"TTS request from {user_id[:8]}... for text: {text[:50]}... speed: {request.speed}")
294
+
295
+ # Normalize speed (clamp to safe range)
296
+ speed = max(0.5, min(1.5, request.speed or 1.0))
297
+
298
+ # Check cache first (include speed in cache key)
299
+ cache_key = f"{text}|speed={speed}" if speed != 1.0 else text
300
+ cached_audio = await cache.get(cache_key)
301
+ if cached_audio:
302
+ logger.info("Returning cached audio")
303
+ return TTSResponse(audio=cached_audio, cached=True, remaining_requests=remaining)
304
+
305
+ try:
306
+ # Generate audio with speed
307
+ audio_bytes = await tts.synthesize(text, speed=speed)
308
+ audio_b64 = base64.b64encode(audio_bytes).decode('utf-8')
309
+
310
+ # Cache result
311
+ await cache.set(cache_key, audio_b64)
312
+
313
+ logger.info(f"Generated audio: {len(audio_bytes)} bytes")
314
+ return TTSResponse(audio=audio_b64, cached=False, remaining_requests=remaining)
315
+
316
+ except Exception as e:
317
+ logger.error(f"TTS synthesis failed: {e}")
318
+ raise HTTPException(status_code=500, detail=f"TTS synthesis failed: {str(e)}")
319
+
320
+ @app.get("/rate-limit/{user_id}")
321
+ async def get_rate_limit_status(
322
+ user_id: str,
323
+ authorization: Optional[str] = Header(None)
324
+ ):
325
+ """
326
+ Get current rate limit status for a user.
327
+ Only accessible with valid Firebase token for the same user.
328
+ """
329
+ user_info = await verify_firebase_token(authorization)
330
+
331
+ if not user_info or user_info["uid"] != user_id:
332
+ raise HTTPException(status_code=401, detail="Unauthorized")
333
+
334
+ today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
335
+ cache_key = f"{user_id}_{today}"
336
+
337
+ if cache_key in rate_limit_cache:
338
+ count = rate_limit_cache[cache_key]["count"]
339
+ else:
340
+ count = 0
341
+
342
+ return {
343
+ "user_id": user_id,
344
+ "date": today,
345
+ "used": count,
346
+ "limit": MAX_REQUESTS_PER_DAY,
347
+ "remaining": max(0, MAX_REQUESTS_PER_DAY - count)
348
+ }
349
+
350
+ # =============================================================================
351
+ # STARTUP
352
+ # =============================================================================
353
+
354
+ @app.on_event("startup")
355
+ async def startup_event():
356
+ """Cleanup old rate limit entries on startup"""
357
+ cleanup_old_rate_limits()
358
+ logger.info("TTS API started")
359
+
360
+ if __name__ == "__main__":
361
+ import uvicorn
362
+ # Port 7860 is the default for Hugging Face Spaces
363
+ uvicorn.run(app, host="0.0.0.0", port=7860)
tts_service.py CHANGED
@@ -1,73 +1,109 @@
1
- """
2
- TTS Service using facebook/mms-tts-yor (Yoruba)
3
- """
4
-
5
- import io
6
- import logging
7
- import asyncio
8
- from functools import lru_cache
9
-
10
- import torch
11
- import numpy as np
12
- import scipy.io.wavfile as wavfile
13
- from transformers import VitsModel, AutoTokenizer
14
-
15
- logger = logging.getLogger(__name__)
16
-
17
-
18
- class TTSService:
19
- def __init__(self):
20
- logger.info("Loading MMS-TTS-YOR model...")
21
-
22
- # Load model and tokenizer
23
- self.model = VitsModel.from_pretrained("facebook/mms-tts-yor")
24
- self.tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-yor")
25
-
26
- # Set to evaluation mode
27
- self.model.eval()
28
-
29
- # Use GPU if available
30
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
31
- self.model = self.model.to(self.device)
32
-
33
- logger.info(f"Model loaded on {self.device}")
34
- logger.info(f"Sampling rate: {self.model.config.sampling_rate}")
35
-
36
- async def synthesize(self, text: str) -> bytes:
37
- """
38
- Synthesize speech from Yoruba text.
39
- Returns WAV audio bytes.
40
- """
41
- # Run synthesis in thread pool to avoid blocking
42
- loop = asyncio.get_event_loop()
43
- return await loop.run_in_executor(None, self._synthesize_sync, text)
44
-
45
- def _synthesize_sync(self, text: str) -> bytes:
46
- """Synchronous synthesis (runs in thread pool)"""
47
-
48
- # Tokenize input
49
- inputs = self.tokenizer(text, return_tensors="pt")
50
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
51
-
52
- # Generate audio
53
- with torch.no_grad():
54
- output = self.model(**inputs).waveform
55
-
56
- # Convert to numpy
57
- waveform = output.squeeze().cpu().numpy()
58
-
59
- # Normalize to 16-bit PCM
60
- waveform = np.clip(waveform, -1.0, 1.0)
61
- waveform_int16 = (waveform * 32767).astype(np.int16)
62
-
63
- # Write to WAV buffer
64
- buffer = io.BytesIO()
65
- wavfile.write(buffer, rate=self.model.config.sampling_rate, data=waveform_int16)
66
-
67
- return buffer.getvalue()
68
-
69
-
70
- # Singleton instance
71
- @lru_cache(maxsize=1)
72
- def get_tts_service() -> TTSService:
73
- return TTSService()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TTS Service using facebook/mms-tts-yor (Yoruba)
3
+ Supports variable speed playback (normal and slow)
4
+ """
5
+
6
+ import io
7
+ import logging
8
+ import asyncio
9
+ from functools import lru_cache
10
+
11
+ import torch
12
+ import numpy as np
13
+ import scipy.io.wavfile as wavfile
14
+ from transformers import VitsModel, AutoTokenizer
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ class TTSService:
20
+ def __init__(self):
21
+ logger.info("Loading MMS-TTS-YOR model...")
22
+
23
+ # Load model and tokenizer
24
+ self.model = VitsModel.from_pretrained("facebook/mms-tts-yor")
25
+ self.tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-yor")
26
+
27
+ # Set to evaluation mode
28
+ self.model.eval()
29
+
30
+ # Use GPU if available
31
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
32
+ self.model = self.model.to(self.device)
33
+
34
+ self.sample_rate = self.model.config.sampling_rate
35
+
36
+ logger.info(f"Model loaded on {self.device}")
37
+ logger.info(f"Sampling rate: {self.sample_rate}")
38
+
39
+ async def synthesize(self, text: str, speed: float = 1.0) -> bytes:
40
+ """
41
+ Synthesize speech from Yoruba text.
42
+
43
+ Args:
44
+ text: Text to synthesize
45
+ speed: Playback speed (0.5 = half speed, 1.0 = normal, 1.5 = faster)
46
+
47
+ Returns WAV audio bytes.
48
+ """
49
+ # Run synthesis in thread pool to avoid blocking
50
+ loop = asyncio.get_event_loop()
51
+ return await loop.run_in_executor(None, self._synthesize_sync, text, speed)
52
+
53
+ def _synthesize_sync(self, text: str, speed: float = 1.0) -> bytes:
54
+ """Synchronous synthesis (runs in thread pool)"""
55
+
56
+ # Tokenize input
57
+ inputs = self.tokenizer(text, return_tensors="pt")
58
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
59
+
60
+ # Generate audio
61
+ with torch.no_grad():
62
+ output = self.model(**inputs).waveform
63
+
64
+ # Convert to numpy
65
+ waveform = output.squeeze().cpu().numpy()
66
+
67
+ # Apply time-stretching for speed change (using simple resampling)
68
+ if speed != 1.0 and speed > 0:
69
+ waveform = self._change_speed(waveform, speed)
70
+
71
+ # Normalize to 16-bit PCM
72
+ waveform = np.clip(waveform, -1.0, 1.0)
73
+ waveform_int16 = (waveform * 32767).astype(np.int16)
74
+
75
+ # Write to WAV buffer
76
+ buffer = io.BytesIO()
77
+ wavfile.write(buffer, rate=self.sample_rate, data=waveform_int16)
78
+
79
+ return buffer.getvalue()
80
+
81
+ def _change_speed(self, waveform: np.ndarray, speed: float) -> np.ndarray:
82
+ """
83
+ Change playback speed using resampling.
84
+ Speed > 1 = faster (shorter audio)
85
+ Speed < 1 = slower (longer audio)
86
+
87
+ This uses simple linear interpolation for speed change without pitch shift.
88
+ """
89
+ if speed == 1.0:
90
+ return waveform
91
+
92
+ # Calculate new length
93
+ original_length = len(waveform)
94
+ new_length = int(original_length / speed)
95
+
96
+ # Create new time indices
97
+ old_indices = np.arange(original_length)
98
+ new_indices = np.linspace(0, original_length - 1, new_length)
99
+
100
+ # Interpolate
101
+ stretched = np.interp(new_indices, old_indices, waveform)
102
+
103
+ return stretched.astype(np.float32)
104
+
105
+
106
+ # Singleton instance
107
+ @lru_cache(maxsize=1)
108
+ def get_tts_service() -> TTSService:
109
+ return TTSService()