"""
TTS Service using facebook/mms-tts-yor (Yoruba).

Supports variable speed playback (normal and slow).
"""

import asyncio
import io
import logging
from functools import lru_cache

import numpy as np
import scipy.io.wavfile as wavfile
import torch
from transformers import VitsModel, AutoTokenizer

logger = logging.getLogger(__name__)


class TTSService:
    """Yoruba text-to-speech wrapper around the VITS model facebook/mms-tts-yor.

    Loads the model once at construction time and exposes an async
    ``synthesize`` entry point that returns 16-bit PCM WAV bytes.
    """

    def __init__(self):
        logger.info("Loading MMS-TTS-YOR model...")

        # Load model and tokenizer from the Hugging Face hub.
        self.model = VitsModel.from_pretrained("facebook/mms-tts-yor")
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-yor")

        # Inference only: disable dropout / training-mode layers.
        self.model.eval()

        # Use GPU if available, otherwise fall back to CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(self.device)

        # Output sampling rate is dictated by the model config.
        self.sample_rate = self.model.config.sampling_rate

        # NOTE: lazy %-style args instead of f-strings so formatting is
        # skipped when the log level filters the record out.
        logger.info("Model loaded on %s", self.device)
        logger.info("Sampling rate: %s", self.sample_rate)

    async def synthesize(self, text: str, speed: float = 1.0) -> bytes:
        """Synthesize speech from Yoruba text.

        Args:
            text: Text to synthesize.
            speed: Playback speed (0.5 = half speed, 1.0 = normal,
                1.5 = faster). Non-positive values are ignored and the
                audio is returned at normal speed.

        Returns:
            WAV audio bytes (16-bit PCM, mono, at the model's sample rate).
        """
        # Run synthesis in a thread pool so the (potentially long) model
        # forward pass does not block the event loop.
        # FIX: asyncio.get_event_loop() is deprecated inside coroutines
        # since Python 3.10; get_running_loop() is the correct call here.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._synthesize_sync, text, speed)

    def _synthesize_sync(self, text: str, speed: float = 1.0) -> bytes:
        """Synchronous synthesis (runs in a thread pool).

        Args:
            text: Text to synthesize.
            speed: Playback speed multiplier; see ``synthesize``.

        Returns:
            WAV audio bytes (16-bit PCM).
        """
        # Tokenize input and move tensors to the model's device.
        inputs = self.tokenizer(text, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Generate audio; no gradients needed at inference time.
        with torch.no_grad():
            output = self.model(**inputs).waveform

        # Convert to a 1-D numpy array on the host.
        waveform = output.squeeze().cpu().numpy()

        # Apply time-stretching for speed change (simple resampling).
        # Non-positive speeds are deliberately ignored (best-effort).
        if speed != 1.0 and speed > 0:
            waveform = self._change_speed(waveform, speed)

        # Clip to [-1, 1] and quantize to 16-bit PCM.
        waveform = np.clip(waveform, -1.0, 1.0)
        waveform_int16 = (waveform * 32767).astype(np.int16)

        # Serialize to an in-memory WAV container.
        buffer = io.BytesIO()
        wavfile.write(buffer, rate=self.sample_rate, data=waveform_int16)
        return buffer.getvalue()

    def _change_speed(self, waveform: np.ndarray, speed: float) -> np.ndarray:
        """Change playback speed using linear-interpolation resampling.

        Speed > 1 = faster (shorter audio); speed < 1 = slower (longer
        audio). NOTE: this is plain resampling, so pitch shifts along
        with the speed (no time-stretch algorithm is applied).

        Args:
            waveform: 1-D float audio samples.
            speed: Positive speed multiplier.

        Returns:
            The resampled waveform as float32.
        """
        if speed == 1.0:
            return waveform

        original_length = len(waveform)
        # FIX: guard the empty case — np.interp raises on an empty
        # sample-point array.
        if original_length == 0:
            return waveform.astype(np.float32)

        # FIX: very large speeds could truncate to a zero-length output;
        # always keep at least one sample.
        new_length = max(1, int(original_length / speed))

        # Map new sample positions back onto the original time axis and
        # linearly interpolate.
        old_indices = np.arange(original_length)
        new_indices = np.linspace(0, original_length - 1, new_length)
        stretched = np.interp(new_indices, old_indices, waveform)

        return stretched.astype(np.float32)


# Singleton instance — lru_cache(maxsize=1) makes repeated calls return
# the same (expensive-to-build) service object.
@lru_cache(maxsize=1)
def get_tts_service() -> TTSService:
    return TTSService()