# yoruba-tts / tts_service.py
# Uploaded by Yurikks ("Upload 2 files", commit 4601f86, verified)
"""
TTS Service using facebook/mms-tts-yor (Yoruba)
Supports variable speed playback (normal and slow)
"""
import io
import logging
import asyncio
from functools import lru_cache
import torch
import numpy as np
import scipy.io.wavfile as wavfile
from transformers import VitsModel, AutoTokenizer
logger = logging.getLogger(__name__)
class TTSService:
    """Text-to-speech service wrapping facebook/mms-tts-yor (Yoruba VITS).

    Loads the model and tokenizer once at construction time, then converts
    Yoruba text to 16-bit PCM WAV bytes. Supports variable-speed playback
    via simple linear-interpolation time-stretching (note: this also shifts
    pitch; it is resampling, not a phase-vocoder stretch).
    """

    def __init__(self):
        """Load the MMS-TTS-YOR model/tokenizer and move the model to GPU if available."""
        logger.info("Loading MMS-TTS-YOR model...")
        # Load model and tokenizer
        self.model = VitsModel.from_pretrained("facebook/mms-tts-yor")
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-yor")
        # Inference only: disables dropout and other training-time behavior.
        self.model.eval()
        # Use GPU if available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(self.device)
        # Output sample rate is dictated by the model's config — do not hard-code it.
        self.sample_rate = self.model.config.sampling_rate
        logger.info(f"Model loaded on {self.device}")
        logger.info(f"Sampling rate: {self.sample_rate}")

    async def synthesize(self, text: str, speed: float = 1.0) -> bytes:
        """
        Synthesize speech from Yoruba text.

        Args:
            text: Text to synthesize.
            speed: Playback speed (0.5 = half speed, 1.0 = normal, 1.5 = faster).
                Non-positive values are silently treated as 1.0 (see _synthesize_sync).

        Returns:
            WAV audio bytes (16-bit PCM, mono, at the model's sample rate).
        """
        # Run synthesis in a thread pool so the (CPU/GPU-heavy) model call
        # does not block the event loop.
        # FIX: use get_running_loop() — get_event_loop() inside a coroutine is
        # deprecated since Python 3.10 and can return the wrong loop.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._synthesize_sync, text, speed)

    def _synthesize_sync(self, text: str, speed: float = 1.0) -> bytes:
        """Synchronous synthesis (runs in a worker thread via run_in_executor)."""
        # Tokenize input and move tensors to the model's device.
        inputs = self.tokenizer(text, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        # Generate audio; no gradients needed for inference.
        with torch.no_grad():
            output = self.model(**inputs).waveform
        # Convert to a 1-D float numpy array on the CPU.
        waveform = output.squeeze().cpu().numpy()
        # Apply time-stretching for speed change (simple resampling).
        # Non-positive speeds are deliberately ignored (treated as normal speed)
        # rather than raising, to keep the API forgiving of bad client input.
        if speed != 1.0 and speed > 0:
            waveform = self._change_speed(waveform, speed)
        # Normalize to 16-bit PCM; clip first so values map cleanly into int16 range.
        waveform = np.clip(waveform, -1.0, 1.0)
        waveform_int16 = (waveform * 32767).astype(np.int16)
        # Write to an in-memory WAV buffer.
        buffer = io.BytesIO()
        wavfile.write(buffer, rate=self.sample_rate, data=waveform_int16)
        return buffer.getvalue()

    def _change_speed(self, waveform: np.ndarray, speed: float) -> np.ndarray:
        """
        Change playback speed by resampling with linear interpolation.

        speed > 1 produces shorter (faster) audio; speed < 1 produces longer
        (slower) audio. Because this is plain resampling, pitch shifts along
        with speed.

        Args:
            waveform: 1-D float audio samples.
            speed: Positive speed factor (caller guarantees speed > 0).

        Returns:
            The resampled waveform as float32.
        """
        if speed == 1.0:
            return waveform
        # New length is inversely proportional to speed.
        original_length = len(waveform)
        new_length = int(original_length / speed)
        # Map the new sample positions onto the original time axis.
        old_indices = np.arange(original_length)
        new_indices = np.linspace(0, original_length - 1, new_length)
        # Linear interpolation between neighboring original samples.
        stretched = np.interp(new_indices, old_indices, waveform)
        return stretched.astype(np.float32)
# Process-wide singleton: lru_cache(maxsize=1) memoizes the first construction,
# so every caller shares the same loaded model instead of reloading it.
@lru_cache(maxsize=1)
def get_tts_service() -> TTSService:
    """Return the shared TTSService, constructing it lazily on first use."""
    return TTSService()