import os
import tarfile
import time
import urllib.request

import gradio as gr
import numpy as np
import sherpa_onnx

# Download and unpack the quantized Parakeet TDT model on first run.
model_dir = "sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8"
if not os.path.exists(model_dir):
    url = "https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8.tar.bz2"
    urllib.request.urlretrieve(url, "model.tar.bz2")
    with tarfile.open("model.tar.bz2") as tar:
        tar.extractall()
    os.remove("model.tar.bz2")
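
# Quick sanity check (a sketch; assumes the archive unpacks these exact files
# into model_dir, matching the paths used below).
for _f in ("encoder.int8.onnx", "decoder.int8.onnx", "joiner.int8.onnx", "tokens.txt"):
    if not os.path.exists(os.path.join(model_dir, _f)):
        raise FileNotFoundError(f"expected {_f} in {model_dir}")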

# Build the streaming recognizer. The Python API goes through the
# from_transducer() factory (it does not accept an OnlineRecognizerConfig
# directly), which also takes the endpointing rules: rule1 fires after 1.0 s
# of trailing silence before any text is decoded, rule2 after 0.5 s of silence
# once there is text, and rule3 force-closes utterances longer than 30 s.
recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
    tokens=os.path.join(model_dir, "tokens.txt"),
    encoder=os.path.join(model_dir, "encoder.int8.onnx"),
    decoder=os.path.join(model_dir, "decoder.int8.onnx"),
    joiner=os.path.join(model_dir, "joiner.int8.onnx"),
    num_threads=2,
    provider="cpu",
    sample_rate=16000,
    feature_dim=80,
    enable_endpoint_detection=True,
    rule1_min_trailing_silence=1.0,
    rule2_min_trailing_silence=0.5,
    rule3_min_utterance_length=30.0,
)
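
# Optional fail-fast probe (a sketch, assuming the exported model is usable as
# a streaming transducer): push one second of silence through a throwaway
# stream so a model/API mismatch surfaces at startup rather than mid-session.
_probe = recognizer.create_stream()
_probe.accept_waveform(16000, np.zeros(16000, dtype=np.float32))
while recognizer.is_ready(_probe):
    recognizer.decode_stream(_probe)
del _probe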


def transcribe(state, audio_chunk):
    """Consume one microphone chunk; return (state, full text, debug log)."""
    if state is None:
        # Per-session state: one decoding stream plus the accumulated text.
        state = {
            "stream": recognizer.create_stream(),
            "transcript": "",       # text committed at endpoints
            "current_partial": "",  # in-progress text for the open utterance
            "log": "",
            "last_time": time.time(),
        }

    try:
        sr, y = audio_chunk
        # Down-mix stereo to mono.
        if y.ndim > 1:
            y = np.mean(y, axis=1)
        # Gradio delivers int16 PCM; scale to float32 in [-1, 1]. Avoid
        # peak-normalizing each chunk: that amplifies near-silent chunks and
        # defeats the silence-based endpoint rules.
        if y.dtype == np.int16:
            y = y.astype(np.float32) / 32768.0
        else:
            y = y.astype(np.float32)
        if not np.any(y):
            state["log"] += "Weak signal detected.\n"
            return state, state["transcript"] + state["current_partial"], state["log"]

        state["stream"].accept_waveform(sr, y)

        # Decode all feature frames that are ready.
        while recognizer.is_ready(state["stream"]):
            recognizer.decode_stream(state["stream"])

        # get_result() returns the current partial hypothesis as a string.
        current_text = recognizer.get_result(state["stream"]).strip()

        if current_text != state["current_partial"]:
            state["current_partial"] = current_text
            # Time since the previous partial update, as a rough latency gauge.
            latency = time.time() - state["last_time"]
            state["log"] += f"Partial update (latency: {latency:.2f}s): {current_text}\n"
            state["last_time"] = time.time()

        # At an endpoint, commit the partial text and reset for the next utterance.
        if recognizer.is_endpoint(state["stream"]):
            if current_text:
                state["transcript"] += current_text + " "
                state["log"] += f"Endpoint detected, committed: {current_text}\n"
            recognizer.reset(state["stream"])
            state["current_partial"] = ""

    except Exception as e:
        state["log"] += f"Error: {e}\n"

    return state, state["transcript"] + state["current_partial"], state["log"]
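
# Hypothetical smoke test: push one silent chunk through transcribe() without
# the UI to verify the state plumbing (uncomment to run).
# _state, _text, _log = transcribe(None, (16000, np.zeros(16000, dtype=np.int16)))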


with gr.Blocks() as demo:
    gr.Markdown("# Real-Time Multilingual Microphone Transcription")
    with gr.Row():
        # Gradio 4+ takes `sources` (a list); older releases used `source`.
        audio = gr.Audio(sources=["microphone"], type="numpy", streaming=True, label="Speak here")
        transcript = gr.Textbox(label="Transcription", interactive=False)
    logs = gr.Textbox(label="Debug Logs", interactive=False, lines=5)
    state = gr.State()

    # Each microphone chunk is routed through transcribe(); state round-trips
    # so every browser session keeps its own stream.
    audio.stream(transcribe, [state, audio], [state, transcript, logs])

demo.launch()