Spaces:
Sleeping
Sleeping
| """ | |
| YorubaApp STT API - Speech-to-Text para Yoruba usando Facebook MMS-1b-all | |
| Deploy: Hugging Face Spaces (Docker) | |
| """ | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from transformers import Wav2Vec2ForCTC, AutoProcessor | |
| import torch | |
| import numpy as np | |
| import base64 | |
| import logging | |
| import asyncio | |
| import io | |
| import tempfile | |
| import os | |
| from typing import Optional | |
| import soundfile as sf | |
| import librosa | |
| # Configurar logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| app = FastAPI( | |
| title="YorubaApp STT API", | |
| description="Speech-to-Text para Yoruba usando Facebook MMS-1b-all", | |
| version="1.0.0" | |
| ) | |
| # CORS para permitir conexões do app | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| # Variáveis globais para o modelo | |
| processor: Optional[AutoProcessor] = None | |
| model: Optional[Wav2Vec2ForCTC] = None | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| async def load_model(): | |
| """Carrega o modelo MMS na inicialização""" | |
| global processor, model | |
| logger.info("Carregando modelo MMS-1b-all...") | |
| MODEL_ID = "facebook/mms-1b-all" | |
| try: | |
| processor = AutoProcessor.from_pretrained(MODEL_ID) | |
| model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID) | |
| # Configurar para Yoruba | |
| processor.tokenizer.set_target_lang("yor") | |
| model.load_adapter("yor") | |
| model.to(device) | |
| model.eval() | |
| logger.info(f"Modelo carregado com sucesso! Device: {device}") | |
| except Exception as e: | |
| logger.error(f"Erro ao carregar modelo: {e}") | |
| raise | |
| def process_audio(audio_data: bytes) -> str: | |
| """Processa áudio e retorna transcrição""" | |
| global processor, model | |
| if processor is None or model is None: | |
| raise RuntimeError("Modelo não carregado") | |
| try: | |
| # Tenta detectar e converter o formato do áudio | |
| audio_np = convert_audio_to_pcm(audio_data) | |
| if audio_np is None or len(audio_np) == 0: | |
| logger.warning("Áudio vazio após conversão") | |
| return "" | |
| # Verifica se há áudio suficiente | |
| if len(audio_np) < 1600: # Menos de 0.1s | |
| logger.warning(f"Áudio muito curto: {len(audio_np)} samples") | |
| return "" | |
| logger.info(f"Processando {len(audio_np)} samples de áudio") | |
| # Processa com MMS | |
| inputs = processor( | |
| audio_np, | |
| sampling_rate=16000, | |
| return_tensors="pt" | |
| ).to(device) | |
| with torch.no_grad(): | |
| outputs = model(**inputs).logits | |
| ids = torch.argmax(outputs, dim=-1)[0] | |
| transcription = processor.decode(ids) | |
| return transcription.strip() | |
| except Exception as e: | |
| logger.error(f"Erro no processamento de áudio: {e}") | |
| return "" | |
| def convert_audio_to_pcm(audio_data: bytes) -> Optional[np.ndarray]: | |
| """ | |
| Converte áudio de qualquer formato (WebM, MP3, M4A, WAV) para PCM 16kHz mono float32. | |
| Retorna numpy array normalizado [-1.0, 1.0] ou None em caso de erro. | |
| """ | |
| try: | |
| # Primeiro, tenta detectar se é PCM raw (como antes) | |
| # PCM 16-bit geralmente tem tamanho par e valores válidos | |
| if len(audio_data) >= 3200 and len(audio_data) % 2 == 0: | |
| # Tenta como PCM 16-bit | |
| try: | |
| pcm_data = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0 | |
| # Verifica se parece áudio válido (não apenas ruído) | |
| if np.abs(pcm_data).max() < 0.01: | |
| logger.info("PCM parece silêncio, tentando como arquivo codificado") | |
| else: | |
| # Verifica se tem variação razoável | |
| std = np.std(pcm_data) | |
| if std > 0.001 and std < 0.5: | |
| logger.info("Detectado como PCM raw 16-bit") | |
| return pcm_data | |
| except: | |
| pass | |
| # Tenta carregar como arquivo de áudio codificado (WebM, MP3, M4A, etc.) | |
| # Usa arquivo temporário porque algumas bibliotecas não aceitam bytes diretamente | |
| with tempfile.NamedTemporaryFile(suffix='.audio', delete=False) as tmp: | |
| tmp.write(audio_data) | |
| tmp_path = tmp.name | |
| try: | |
| # Usa librosa para carregar e resamplear automaticamente para 16kHz | |
| audio_np, sr = librosa.load(tmp_path, sr=16000, mono=True) | |
| logger.info(f"Áudio carregado via librosa: {len(audio_np)} samples, sr={sr}") | |
| return audio_np | |
| except Exception as e1: | |
| logger.warning(f"librosa falhou: {e1}") | |
| # Fallback: tenta com soundfile | |
| try: | |
| audio_np, sr = sf.read(io.BytesIO(audio_data)) | |
| # Converte para mono se estéreo | |
| if len(audio_np.shape) > 1: | |
| audio_np = np.mean(audio_np, axis=1) | |
| # Resamplea para 16kHz se necessário | |
| if sr != 16000: | |
| audio_np = librosa.resample(audio_np, orig_sr=sr, target_sr=16000) | |
| logger.info(f"Áudio carregado via soundfile: {len(audio_np)} samples") | |
| return audio_np.astype(np.float32) | |
| except Exception as e2: | |
| logger.warning(f"soundfile falhou: {e2}") | |
| # Último fallback: assume PCM raw | |
| logger.info("Fallback para PCM raw") | |
| return np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0 | |
| finally: | |
| # Limpa arquivo temporário | |
| try: | |
| os.unlink(tmp_path) | |
| except: | |
| pass | |
| except Exception as e: | |
| logger.error(f"Erro na conversão de áudio: {e}") | |
| return None | |
| # ============== WebSocket Endpoint (Streaming) ============== | |
| async def websocket_transcribe(websocket: WebSocket): | |
| """ | |
| WebSocket endpoint para transcrição em streaming. | |
| Protocolo: | |
| - Cliente envia: {"type": "audio_chunk", "audio": "<base64>"} | |
| - Cliente envia: {"type": "end"} para finalizar | |
| - Servidor responde: {"type": "partial", "text": "..."} durante streaming | |
| - Servidor responde: {"type": "final", "text": "..."} ao finalizar | |
| """ | |
| await websocket.accept() | |
| audio_buffer = bytearray() | |
| chunk_count = 0 | |
| logger.info("WebSocket conectado") | |
| try: | |
| while True: | |
| data = await websocket.receive_json() | |
| msg_type = data.get("type") | |
| if msg_type == "audio_chunk": | |
| # Decodifica chunk base64 e adiciona ao buffer | |
| try: | |
| chunk = base64.b64decode(data.get("audio", "")) | |
| audio_buffer.extend(chunk) | |
| chunk_count += 1 | |
| logger.debug(f"Chunk {chunk_count}: +{len(chunk)} bytes, total: {len(audio_buffer)}") | |
| # Para streaming, processar a cada ~2 segundos de áudio (estimado) | |
| # Mas só para partial results - o final será processado no "end" | |
| if len(audio_buffer) >= 64000 and chunk_count % 5 == 0: | |
| transcription = process_audio(bytes(audio_buffer)) | |
| if transcription: | |
| await websocket.send_json({ | |
| "type": "partial", | |
| "text": transcription | |
| }) | |
| logger.info(f"Parcial: {transcription}") | |
| except Exception as e: | |
| logger.error(f"Erro ao processar chunk: {e}") | |
| elif msg_type == "end": | |
| # Processa áudio final completo | |
| logger.info(f"Finalizando - {chunk_count} chunks, {len(audio_buffer)} bytes") | |
| if audio_buffer: | |
| final_text = process_audio(bytes(audio_buffer)) | |
| logger.info(f"Transcrição final: '{final_text}' (len={len(final_text)})") | |
| await websocket.send_json({ | |
| "type": "final", | |
| "text": final_text | |
| }) | |
| else: | |
| logger.warning("Buffer vazio no end") | |
| await websocket.send_json({ | |
| "type": "final", | |
| "text": "" | |
| }) | |
| # Limpa buffer para próxima sessão | |
| audio_buffer = bytearray() | |
| chunk_count = 0 | |
| elif msg_type == "ping": | |
| # Keep-alive | |
| await websocket.send_json({"type": "pong"}) | |
| except WebSocketDisconnect: | |
| logger.info("WebSocket desconectado") | |
| except Exception as e: | |
| logger.error(f"Erro no WebSocket: {e}") | |
| # ============== REST Endpoint (Fallback) ============== | |
| class STTRequest(BaseModel): | |
| audio_base64: str | |
| class STTResponse(BaseModel): | |
| text: str | |
| success: bool | |
| error: Optional[str] = None | |
| async def speech_to_text(request: STTRequest): | |
| """ | |
| Endpoint REST para transcrição (fallback se WebSocket não funcionar). | |
| Body: | |
| - audio_base64: Áudio em base64 (PCM 16-bit, 16kHz, mono) | |
| Returns: | |
| - text: Transcrição em Yoruba | |
| - success: Se a operação foi bem sucedida | |
| """ | |
| try: | |
| audio_data = base64.b64decode(request.audio_base64) | |
| transcription = process_audio(audio_data) | |
| return STTResponse( | |
| text=transcription, | |
| success=True | |
| ) | |
| except Exception as e: | |
| logger.error(f"Erro no endpoint REST: {e}") | |
| return STTResponse( | |
| text="", | |
| success=False, | |
| error=str(e) | |
| ) | |
| # ============== Health Check ============== | |
| async def health_check(): | |
| """Health check endpoint""" | |
| return { | |
| "status": "healthy", | |
| "model_loaded": model is not None, | |
| "device": device | |
| } | |
| async def root(): | |
| """Root endpoint com informações da API""" | |
| return { | |
| "name": "YorubaApp STT API", | |
| "version": "1.0.0", | |
| "model": "facebook/mms-1b-all", | |
| "language": "Yoruba (yor)", | |
| "endpoints": { | |
| "websocket": "/ws/transcribe", | |
| "rest": "/stt", | |
| "health": "/health" | |
| } | |
| } | |