Yurikks committed on
Commit
a3fa620
·
1 Parent(s): d1848fb

Deploy Yoruba TTS API with facebook/mms-tts-yor

Browse files
Files changed (6) hide show
  1. Dockerfile +31 -0
  2. README.md +75 -6
  3. cache.py +96 -0
  4. main.py +95 -0
  5. requirements.txt +14 -0
  6. tts_service.py +73 -0
Dockerfile ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# YorubaApp TTS Backend - Hugging Face Spaces
FROM python:3.11-slim

# Flush Python output immediately so logs show up in the Space console.
ENV PYTHONUNBUFFERED=1

# Hugging Face Spaces runs Docker containers as UID 1000. Create that user up
# front so the model cache written at build time lives in a home directory the
# runtime user can actually read.
RUN useradd -m -u 1000 user

WORKDIR /app

# Install system dependencies (skip recommended packages to keep the image small)
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first for caching
COPY requirements.txt .

# Install Python dependencies (system-wide, still as root)
RUN pip install --no-cache-dir -r requirements.txt

# Everything below runs as the unprivileged runtime user.
USER user
ENV HOME=/home/user

# Pre-download model at build time (avoids timeout on startup).
# Runs as `user`, so the cache lands under /home/user/.cache/huggingface.
RUN python -c "from transformers import VitsModel, AutoTokenizer; \
    print('Downloading facebook/mms-tts-yor model...'); \
    VitsModel.from_pretrained('facebook/mms-tts-yor'); \
    AutoTokenizer.from_pretrained('facebook/mms-tts-yor'); \
    print('Model downloaded successfully!')"

# Copy application code
COPY --chown=user . .

# Hugging Face Spaces uses port 7860
EXPOSE 7860

# Run the application
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,11 +1,80 @@
1
  ---
2
- title: Yoruba Tts
3
- emoji: 🌍
4
- colorFrom: red
5
- colorTo: yellow
6
  sdk: docker
 
7
  pinned: false
8
- short_description: yorubaapp
9
  ---
10
 
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Yoruba TTS API
3
+ emoji: "\U0001F5E3\uFE0F"
4
+ colorFrom: yellow
5
+ colorTo: orange
6
  sdk: docker
7
+ app_port: 7860
8
  pinned: false
9
+ license: cc-by-nc-4.0
10
  ---
11
 
12
+ # Yoruba TTS API
13
+
14
+ Text-to-Speech API for Yoruba language using the `facebook/mms-tts-yor` model.
15
+
16
+ ## Model Information
17
+
18
+ - **Model**: [facebook/mms-tts-yor](https://huggingface.co/facebook/mms-tts-yor)
19
+ - **Architecture**: VITS (Variational Inference with adversarial learning for end-to-end Text-to-Speech)
20
+ - **Parameters**: 36.3M
21
+ - **License**: CC-BY-NC 4.0 (non-commercial use)
22
+
23
+ ## API Endpoints
24
+
25
+ ### POST /tts
26
+
27
+ Generate speech from Yoruba text.
28
+
29
+ **Request:**
30
+ ```json
31
+ {
32
+ "text": "Bawo ni"
33
+ }
34
+ ```
35
+
36
+ **Response:**
37
+ ```json
38
+ {
39
+ "audio": "UklGRiQAAABXQVZFZm10...",
40
+ "cached": false
41
+ }
42
+ ```
43
+
44
+ The `audio` field contains base64-encoded WAV audio.
45
+
46
+ ### GET /health
47
+
48
+ Check service health.
49
+
50
+ **Response:**
51
+ ```json
52
+ {
53
+ "status": "healthy",
54
+ "model": "facebook/mms-tts-yor"
55
+ }
56
+ ```
57
+
58
+ ## Usage Example
59
+
60
+ ```python
61
+ import requests
62
+ import base64
63
+
64
+ response = requests.post(
65
+ "https://YOUR-SPACE.hf.space/tts",
66
+ json={"text": "Bawo ni"}
67
+ )
68
+
69
+ audio_b64 = response.json()["audio"]
70
+ audio_bytes = base64.b64decode(audio_b64)
71
+
72
+ with open("output.wav", "wb") as f:
73
+ f.write(audio_bytes)
74
+ ```
75
+
76
+ ## Limitations
77
+
78
+ - Maximum text length: 500 characters
79
+ - Audio format: WAV (16-bit PCM)
80
+ - Sample rate: model default (16000 Hz for MMS-TTS checkpoints; see `model.config.sampling_rate`)
cache.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TTS Cache using Redis or in-memory fallback
3
+ """
4
+
5
+ import hashlib
6
+ import logging
7
+ import os
8
+ from typing import Optional
9
+
10
logger = logging.getLogger(__name__)

# Redis is an optional dependency: when the client library cannot be imported
# the cache falls back to a per-process in-memory dict.
try:
    import redis.asyncio as redis
    REDIS_AVAILABLE = True
except ImportError:
    REDIS_AVAILABLE = False
    logger.warning("Redis not available, using in-memory cache")


class TTSCache:
    """Cache of base64-encoded audio, keyed by the text it was synthesized from.

    Redis is used when a client could be created; otherwise (or whenever a
    Redis call raises) entries go to a bounded in-memory dict with
    least-recently-used eviction. The TTL only applies to Redis entries;
    in-memory entries live until they are evicted or the cache is cleared.
    """

    def __init__(self):
        """Set defaults and create a Redis client if the library is available.

        The Redis URL is read from the ``REDIS_URL`` environment variable and
        defaults to ``redis://localhost:6379``.
        """
        self.ttl = 86400 * 7  # 7 days, in seconds (passed to SETEX)
        self.redis_client = None
        # Relies on dict insertion order: first key == least recently used.
        self.memory_cache: dict[str, str] = {}
        self.max_memory_items = 1000

        redis_url = os.environ.get("REDIS_URL", "redis://localhost:6379")

        if REDIS_AVAILABLE:
            try:
                self.redis_client = redis.from_url(redis_url, decode_responses=True)
                # The URL is deliberately not logged: it may embed a password.
                # NOTE(review): from_url appears to build the client lazily, so
                # an unreachable server likely surfaces on the first get/set
                # rather than here - confirm against the redis-py docs.
                logger.info("Redis cache initialized")
            except Exception as e:
                logger.warning("Redis connection failed, using memory cache: %s", e)
                self.redis_client = None

    def _key(self, text: str) -> str:
        """Return the cache key for *text*: ``tts:`` + hex MD5 of its UTF-8 bytes."""
        # MD5 is only a compact fingerprint here, not a security control;
        # flagging that keeps it usable on FIPS-restricted Python builds.
        digest = hashlib.md5(text.encode(), usedforsecurity=False).hexdigest()
        return f"tts:{digest}"

    async def get(self, text: str) -> Optional[str]:
        """Return the cached base64 audio for *text*, or ``None`` on a miss.

        Redis is consulted first; a Redis miss or error falls through to the
        in-memory cache.
        """
        key = self._key(text)

        if self.redis_client:
            try:
                result = await self.redis_client.get(key)
            except Exception as e:
                # Best effort: a broken Redis must not fail the request.
                logger.warning("Redis get failed: %s", e)
            else:
                if result is not None:
                    logger.debug("Redis cache hit for key: %s", key)
                    return result

        result = self.memory_cache.pop(key, None)
        if result is not None:
            # Re-insert so the entry becomes the most recently used one.
            self.memory_cache[key] = result
            logger.debug("Memory cache hit for key: %s", key)
        return result

    async def set(self, text: str, audio_b64: str):
        """Store base64 audio for *text*.

        Writes to Redis with the configured TTL when possible; if there is no
        Redis client or the write raises, stores the entry in the in-memory
        cache, evicting least-recently-used entries to stay within
        ``max_memory_items``.
        """
        key = self._key(text)

        if self.redis_client:
            try:
                await self.redis_client.setex(key, self.ttl, audio_b64)
                logger.debug("Cached to Redis: %s", key)
                return
            except Exception as e:
                logger.warning("Redis set failed: %s", e)

        if self.max_memory_items <= 0:
            # A zero-capacity cache stores nothing (and has nothing to evict).
            return

        # Drop any previous value first so that overwriting an existing key
        # never pushes an unrelated entry out of a full cache.
        self.memory_cache.pop(key, None)
        while len(self.memory_cache) >= self.max_memory_items:
            oldest_key = next(iter(self.memory_cache))
            del self.memory_cache[oldest_key]
            logger.debug("Evicted from memory cache: %s", oldest_key)

        self.memory_cache[key] = audio_b64
        logger.debug("Cached to memory: %s", key)

    async def clear(self):
        """Remove every in-memory entry and every ``tts:*`` key from Redis."""
        self.memory_cache.clear()

        if self.redis_client:
            try:
                # Only touch this service's keys; the database may be shared.
                async for key in self.redis_client.scan_iter("tts:*"):
                    await self.redis_client.delete(key)
            except Exception as e:
                logger.warning("Redis clear failed: %s", e)
main.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TTS Backend for YorubaApp
3
+ Uses facebook/mms-tts-yor model for Yoruba text-to-speech
4
+ """
5
+
6
+ from fastapi import FastAPI, HTTPException
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from pydantic import BaseModel
9
+ import base64
10
+ import logging
11
+
12
+ from tts_service import TTSService
13
+ from cache import TTSCache
14
+
15
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="YorubaApp TTS API",
    description="Text-to-Speech API for Yoruba language using MMS-TTS-YOR",
    version="1.0.0",
)

# CORS - allow requests from Expo dev server and production.
# Credentials stay disabled while origins are wildcarded: the CORS spec does
# not allow "*" together with credentialed requests, and this API uses neither
# cookies nor HTTP auth. If credentials are ever needed, list the allowed
# origins explicitly first.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure for production
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize services (module-level singletons shared by all requests).
# TTSService() loads the model eagerly, so importing this module is slow.
tts = TTSService()
cache = TTSCache()
37
+
38
+
39
class TTSRequest(BaseModel):
    """JSON body accepted by POST /tts."""

    # Raw text to synthesize, exactly as sent by the client.
    text: str
41
+
42
+
43
class TTSResponse(BaseModel):
    """JSON body returned by POST /tts."""

    # Base64-encoded WAV audio.
    audio: str
    # True when the audio was served from the cache instead of synthesized.
    cached: bool
46
+
47
+
48
@app.get("/")
async def root():
    """Landing endpoint: return a static payload naming the service."""
    payload = {"status": "ok", "service": "YorubaApp TTS API"}
    return payload
51
+
52
+
53
@app.get("/health")
async def health():
    """Health probe: return a static payload with the model identifier.

    The response is constant; it does not inspect the model or the cache.
    """
    report = {"status": "healthy", "model": "facebook/mms-tts-yor"}
    return report
56
+
57
+
58
@app.post("/tts", response_model=TTSResponse)
async def text_to_speech(request: TTSRequest):
    """Synthesize speech for the request text, serving from the cache when possible.

    Returns:
        TTSResponse with base64-encoded WAV audio and a flag telling whether
        it came from the cache.

    Raises:
        HTTPException 400: the text is empty after stripping, or longer than
            500 characters.
        HTTPException 500: synthesis failed. The response carries a generic
            message; the underlying error is only written to the server log.
    """
    text = request.text.strip()

    if not text:
        raise HTTPException(status_code=400, detail="Text is required")

    if len(text) > 500:
        raise HTTPException(status_code=400, detail="Text too long (max 500 characters)")

    logger.info("TTS request for text: %s...", text[:50])

    # Check cache first
    cached_audio = await cache.get(text)
    if cached_audio:
        logger.info("Returning cached audio")
        return TTSResponse(audio=cached_audio, cached=True)

    try:
        # Generate audio
        audio_bytes = await tts.synthesize(text)
        audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")

        # Cache result
        await cache.set(text, audio_b64)
    except Exception as e:
        # Log with traceback for operators, but do not echo internal error
        # text (paths, library messages) back to the client.
        logger.exception("TTS synthesis failed")
        raise HTTPException(status_code=500, detail="TTS synthesis failed") from e

    logger.info("Generated audio: %d bytes", len(audio_bytes))
    return TTSResponse(audio=audio_b64, cached=False)
90
+
91
+
92
if __name__ == "__main__":
    # Direct execution (python main.py): start the ASGI server in-process.
    import uvicorn

    # Listen on all interfaces; 7860 is the default port for Hugging Face Spaces.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FastAPI and server
2
+ fastapi==0.115.6
3
+ uvicorn[standard]==0.34.0
4
+ pydantic==2.10.3
5
+
6
+ # TTS Model (transformers >= 4.33 REQUIRED for MMS-TTS)
7
+ torch>=2.0.0
8
+ transformers>=4.33.0
9
+ accelerate>=0.21.0
10
+ scipy>=1.14.0
11
+ numpy>=1.26.0
12
+
13
+ # Utilities
14
+ python-dotenv>=1.0.0
tts_service.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TTS Service using facebook/mms-tts-yor (Yoruba)
3
+ """
4
+
5
+ import io
6
+ import logging
7
+ import asyncio
8
+ from functools import lru_cache
9
+
10
+ import torch
11
+ import numpy as np
12
+ import scipy.io.wavfile as wavfile
13
+ from transformers import VitsModel, AutoTokenizer
14
+
15
logger = logging.getLogger(__name__)


class TTSService:
    """Yoruba text-to-speech backed by the ``facebook/mms-tts-yor`` VITS model."""

    def __init__(self):
        """Load the model and tokenizer and move the model to GPU when available."""
        logger.info("Loading MMS-TTS-YOR model...")

        # Load model and tokenizer
        self.model = VitsModel.from_pretrained("facebook/mms-tts-yor")
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-yor")

        # Set to evaluation mode
        self.model.eval()

        # Use GPU if available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(self.device)

        logger.info("Model loaded on %s", self.device)
        logger.info("Sampling rate: %s", self.model.config.sampling_rate)

    async def synthesize(self, text: str) -> bytes:
        """Synthesize speech from Yoruba text.

        The blocking model call is handed to the event loop's default executor
        so the loop stays free to serve other requests.

        Returns:
            WAV audio bytes (16-bit PCM).
        """
        # get_running_loop() is the supported way to reach the loop from inside
        # a coroutine; get_event_loop() is deprecated for this use.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._synthesize_sync, text)

    def _synthesize_sync(self, text: str) -> bytes:
        """Run tokenization and inference synchronously and return WAV bytes."""
        # Tokenize input and move every tensor to the model's device
        inputs = self.tokenizer(text, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Generate audio (no gradients needed for inference)
        with torch.no_grad():
            output = self.model(**inputs).waveform

        # Drop singleton dimensions and convert to numpy on the CPU
        waveform = output.squeeze().cpu().numpy()

        return self._waveform_to_wav(waveform, self.model.config.sampling_rate)

    @staticmethod
    def _waveform_to_wav(waveform, sampling_rate: int) -> bytes:
        """Encode a float waveform as 16-bit PCM WAV bytes.

        Samples are clipped to [-1.0, 1.0], scaled by 32767 and truncated to
        int16 before being written with ``scipy.io.wavfile``.
        """
        clipped = np.clip(waveform, -1.0, 1.0)
        pcm = (clipped * 32767).astype(np.int16)

        buffer = io.BytesIO()
        wavfile.write(buffer, rate=sampling_rate, data=pcm)
        return buffer.getvalue()
68
+
69
+
70
+ # Singleton instance
71
@lru_cache(maxsize=1)
def get_tts_service() -> TTSService:
    """Return the shared TTSService, constructing it on the first call.

    The function takes no arguments, so the size-1 ``lru_cache`` memoizes its
    single result and later calls reuse the same instance.
    """
    service = TTSService()
    return service