"""
TTS Service using facebook/mms-tts-yor (Yoruba)
Supports variable speed playback (normal and slow)
"""
import io
import logging
import asyncio
from functools import lru_cache

import torch
import numpy as np
import scipy.io.wavfile as wavfile
from transformers import VitsModel, AutoTokenizer

logger = logging.getLogger(__name__)
class TTSService:
    """Yoruba text-to-speech backed by facebook/mms-tts-yor (VITS).

    Loads the model and tokenizer once in ``__init__`` and exposes an
    async ``synthesize`` entry point that returns WAV-encoded bytes,
    with optional playback-speed adjustment via resampling.
    """

    def __init__(self):
        logger.info("Loading MMS-TTS-YOR model...")
        # Load model and tokenizer
        self.model = VitsModel.from_pretrained("facebook/mms-tts-yor")
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-yor")
        # Inference only — disable dropout / training-mode layers.
        self.model.eval()
        # Use GPU if available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(self.device)
        # Output sample rate comes from the model config (VITS models fix it).
        self.sample_rate = self.model.config.sampling_rate
        logger.info(f"Model loaded on {self.device}")
        logger.info(f"Sampling rate: {self.sample_rate}")

    async def synthesize(self, text: str, speed: float = 1.0) -> bytes:
        """
        Synthesize speech from Yoruba text.

        Args:
            text: Text to synthesize
            speed: Playback speed (0.5 = half speed, 1.0 = normal, 1.5 = faster)

        Returns WAV audio bytes.
        """
        # Run synthesis in a thread pool to avoid blocking the event loop.
        # get_running_loop() is the correct call inside a coroutine;
        # get_event_loop() here is deprecated since Python 3.10.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._synthesize_sync, text, speed)

    def _synthesize_sync(self, text: str, speed: float = 1.0) -> bytes:
        """Synchronous synthesis (runs in thread pool)."""
        # Tokenize input and move tensors to the model's device.
        inputs = self.tokenizer(text, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        # Generate audio (no autograd needed for inference).
        with torch.no_grad():
            output = self.model(**inputs).waveform
        # Convert to a 1-D numpy float waveform on the CPU.
        waveform = output.squeeze().cpu().numpy()
        # Apply speed change via resampling; non-positive speeds are ignored.
        if speed > 0 and speed != 1.0:
            waveform = self._change_speed(waveform, speed)
        # Normalize to 16-bit PCM range before encoding.
        waveform = np.clip(waveform, -1.0, 1.0)
        waveform_int16 = (waveform * 32767).astype(np.int16)
        # Write to an in-memory WAV buffer.
        buffer = io.BytesIO()
        wavfile.write(buffer, rate=self.sample_rate, data=waveform_int16)
        return buffer.getvalue()

    def _change_speed(self, waveform: np.ndarray, speed: float) -> np.ndarray:
        """
        Change playback speed by resampling with linear interpolation.

        Speed > 1 = faster (shorter audio)
        Speed < 1 = slower (longer audio)

        NOTE: plain resampling also shifts pitch proportionally to the
        speed factor (the original comment claimed otherwise). A
        pitch-preserving change would need a time-stretch algorithm
        such as phase vocoding or WSOLA.
        """
        if speed == 1.0:
            return waveform
        # New length is inversely proportional to speed.
        original_length = len(waveform)
        new_length = int(original_length / speed)
        # Map the new sample positions onto the original time axis.
        old_indices = np.arange(original_length)
        new_indices = np.linspace(0, original_length - 1, new_length)
        # Linear interpolation between neighboring original samples.
        stretched = np.interp(new_indices, old_indices, waveform)
        return stretched.astype(np.float32)
# Singleton accessor: the comment/intent says "singleton", but the original
# constructed (and fully re-loaded) a new TTSService on EVERY call.
# lru_cache(maxsize=1) — already imported at the top of this file but unused —
# memoizes the zero-argument call so the model is loaded exactly once.
@lru_cache(maxsize=1)
def get_tts_service() -> TTSService:
    """Return the shared TTSService instance, creating it on first use."""
    return TTSService()