刘鑫 committed on
Commit
384093d
·
1 Parent(s): e39dcd8

update serve model to voxcpm 1.5

Browse files
Files changed (1) hide show
  1. app.py +162 -92
app.py CHANGED
@@ -39,39 +39,39 @@ logger.info("🚀 VoxCPM应用启动中...")
39
  logger.info(f"Python版本: {sys.version}")
40
  logger.info(f"工作目录: {os.getcwd()}")
41
  logger.info(f"环境变量PORT: {os.environ.get('PORT', '未设置')}")
42
- logger.info(f"环境变量RAY_SERVE_URL: {os.environ.get('RAY_SERVE_URL', '未设置')}")
43
  logger.info("="*50)
44
 
45
 
46
- class RayServeVoxCPMClient:
47
- """Client wrapper that talks to Ray Serve TTS API."""
48
 
49
  def __init__(self) -> None:
50
- logger.info("📡 初始化RayServeVoxCPMClient...")
51
 
52
  try:
53
- # Ray Serve API URL (can be overridden via env)
54
- self.RAY_SERVE_DEFAULT_URL = "https://d09181959-pytorch251-cuda124-u-5512-sj7yq0o5-8970.550w.link"
55
  self.api_url = self._resolve_server_url()
56
- logger.info(f"🔗 准备连接到Ray Serve API: {self.api_url}")
57
 
58
  # Test connection
59
- logger.info("⏳ 测试Ray Serve连接...")
60
  health_start = time.time()
61
  health_response = requests.get(f"{self.api_url}/health", timeout=10)
62
  health_response.raise_for_status()
63
  health_time = time.time() - health_start
64
- logger.info(f"✅ 成功连接到Ray Serve API: {self.api_url} (耗时: {health_time:.3f}秒)")
65
 
66
  except Exception as e:
67
- logger.error(f"❌ 初始化RayServeVoxCPMClient失败: {e}")
68
  logger.error(f"错误详情: {traceback.format_exc()}")
69
  raise
70
 
71
  # ----------- Helpers -----------
72
  def _resolve_server_url(self) -> str:
73
- """Resolve Ray Serve API base URL, prefer env RAY_SERVE_URL."""
74
- return os.environ.get("RAY_SERVE_URL", self.RAY_SERVE_DEFAULT_URL).rstrip("/")
75
 
76
  def _audio_file_to_base64(self, audio_file_path: str) -> str:
77
  """
@@ -136,7 +136,7 @@ class RayServeVoxCPMClient:
136
 
137
  # ----------- Functional endpoints -----------
138
  def prompt_wav_recognition(self, prompt_wav: Optional[str]) -> str:
139
- """Use Ray Serve ASR API for speech recognition."""
140
  logger.info(f"🎵 开始语音识别,输入文件: {prompt_wav}")
141
 
142
  if prompt_wav is None or not prompt_wav.strip():
@@ -151,12 +151,9 @@ class RayServeVoxCPMClient:
151
  audio_base64 = self._audio_file_to_base64(prompt_wav)
152
  convert_time = time.time() - convert_start
153
 
154
- # 构建ASR请求 - 匹配 voxcpm_api.py 格式
155
  asr_request = {
156
- "audio_data": audio_base64,
157
- "language": "auto",
158
- "use_itn": True,
159
- "reqid": str(uuid.uuid4())
160
  }
161
 
162
  # 调用ASR接口
@@ -177,19 +174,15 @@ class RayServeVoxCPMClient:
177
  logger.info(f"⏱️ ASR总耗时: {total_time:.3f}秒")
178
  logger.info(f"🔍 完整的ASR响应: {result_data}")
179
 
180
- # 检查响应状态 - 基于实际响应格式,ASR有多种成功标识
181
- if isinstance(result_data, dict) and "text" in result_data and (
182
- result_data.get("code") == 3000 or result_data.get("status") == "ok"
183
- ):
184
  recognized_text = result_data.get("text", "")
185
  logger.info(f"🎯 识别结果: '{recognized_text}'")
186
  return recognized_text
187
  else:
188
  logger.warning(f"⚠️ ASR响应验证失败:")
189
  if isinstance(result_data, dict):
190
- logger.warning(f" - code字段: {result_data.get('code')}")
191
  logger.warning(f" - 是否有text字段: {'text' in result_data}")
192
- logger.warning(f" - message字段: {result_data.get('message')}")
193
  logger.warning(f"⚠️ 完整ASR响应: {result_data}")
194
  return ""
195
 
@@ -198,7 +191,7 @@ class RayServeVoxCPMClient:
198
  logger.error(f"错误详情: {traceback.format_exc()}")
199
  return ""
200
 
201
- def _call_ray_serve_generate(
202
  self,
203
  text: str,
204
  prompt_wav_path: Optional[str] = None,
@@ -209,83 +202,157 @@ class RayServeVoxCPMClient:
209
  denoise: bool = True,
210
  ) -> Tuple[int, np.ndarray]:
211
  """
212
- Call Ray Serve /generate API and return (sample_rate, waveform).
 
 
 
213
  """
214
  try:
215
  start_time = time.time()
216
 
217
- # 构建请求数据 - 匹配 voxcpm_api.py 格式
218
- prepare_start = time.time()
219
- request_data = {
220
- "text": text,
221
- "cfg_value": cfg_value,
222
- "inference_timesteps": inference_timesteps,
223
- "do_normalize": do_normalize,
224
- "denoise": denoise,
225
- "reqid": str(uuid.uuid4())
226
- }
227
-
228
- # 如果有参考音频和文本,添加到请求中
229
- if prompt_wav_path and prompt_text:
230
- logger.info("🎭 使用语音克隆模式")
231
  convert_start = time.time()
232
  audio_base64 = self._audio_file_to_base64(prompt_wav_path)
233
  convert_time = time.time() - convert_start
 
234
 
235
- request_data.update({
236
- "prompt_wav": audio_base64,
237
- "prompt_text": prompt_text
238
- })
239
- else:
240
- logger.info("🎤 使用默认语音模式")
241
- prepare_time = time.time() - prepare_start
242
-
243
- # 调用生成接口
244
- api_start = time.time()
245
- response = requests.post(
246
- f"{self.api_url}/generate",
247
- json=request_data,
248
- headers={"Content-Type": "application/json"},
249
- timeout=120 # TTS可能需要较长时间
250
- )
251
- response.raise_for_status()
252
- api_time = time.time() - api_start
253
-
254
- result_data = response.json()
255
-
256
- # 检查响应状态 - 基于实际响应格式,TTS响应没有code字段,只检查data
257
- if isinstance(result_data, dict) and "data" in result_data and isinstance(result_data["data"], str) and result_data["data"]:
258
- # 成功生成音频
259
- audio_base64 = result_data["data"]
260
 
261
- # 将base64音频转换为numpy数组
262
- decode_start = time.time()
263
- sample_rate, audio_array = self._base64_to_audio_array(audio_base64)
264
- decode_time = time.time() - decode_start
265
- total_time = time.time() - start_time
266
 
267
- logger.info(f"📈 性能指标: API={api_time:.3f}s, 解码={decode_time:.3f}s, 总计={total_time:.3f}s")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
269
- return sample_rate, audio_array
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  else:
271
- logger.error(f"❌ 响应验证失败:")
272
- logger.error(f" - 是否为字典: {isinstance(result_data, dict)}")
273
- if isinstance(result_data, dict):
274
- logger.error(f" - 是否有data字段: {'data' in result_data}")
275
- if "data" in result_data:
276
- logger.error(f" - data字段类型: {type(result_data['data'])}")
277
- logger.error(f" - data字段是否为字符串: {isinstance(result_data['data'], str)}")
278
- if isinstance(result_data['data'], str):
279
- logger.error(f" - data字段是否非空: {bool(result_data['data'])}")
280
- logger.error(f" - data字段长度: {len(result_data['data'])}")
281
- logger.error(f" 完整响应内容: {result_data}")
282
- raise RuntimeError(f"Ray Serve没有返回有效的音频数据。响应: {result_data}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
  except requests.exceptions.RequestException as e:
285
- logger.error(f"❌ Ray Serve请求失败: {e}")
286
- raise RuntimeError(f"Failed to connect Ray Serve TTS service: {e}. Check RAY_SERVE_URL='{self.api_url}' and service status")
287
  except Exception as e:
288
- logger.error(f"❌ Ray Serve调用异常: {e}")
289
  raise
290
 
291
  def generate_tts_audio(
@@ -316,7 +383,7 @@ class RayServeVoxCPMClient:
316
  cfg_value = cfg_value_input if cfg_value_input is not None else 2.0
317
  inference_timesteps = inference_timesteps_input if inference_timesteps_input is not None else 10
318
 
319
- sr, wav_np = self._call_ray_serve_generate(
320
  text=text,
321
  prompt_wav_path=prompt_wav_path,
322
  prompt_text=prompt_text,
@@ -336,7 +403,7 @@ class RayServeVoxCPMClient:
336
 
337
  # ---------- UI Builders ----------
338
 
339
- def create_demo_interface(client: RayServeVoxCPMClient):
340
  """Build the Gradio UI for Gradio API VoxCPM client."""
341
  logger.info("🎨 开始创建Gradio界面...")
342
 
@@ -377,6 +444,9 @@ def create_demo_interface(client: RayServeVoxCPMClient):
377
  """
378
  ) as interface:
379
  gr.HTML('<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm-logo.png" alt="VoxCPM Logo"></div>')
 
 
 
380
 
381
  # Quick Start
382
  with gr.Accordion("📋 Quick Start Guide |快速入门", open=False, elem_id="acc_quick"):
@@ -497,9 +567,9 @@ def run_demo():
497
 
498
  try:
499
  # 创建客户端
500
- logger.info("📡 创建Ray Serve API客户端...")
501
- client = RayServeVoxCPMClient()
502
- logger.info("✅ Ray Serve API客户端创建成功")
503
 
504
  # 创建界面
505
  logger.info("🎨 创建Gradio界面...")
 
39
  logger.info(f"Python版本: {sys.version}")
40
  logger.info(f"工作目录: {os.getcwd()}")
41
  logger.info(f"环境变量PORT: {os.environ.get('PORT', '未设置')}")
42
+ logger.info(f"环境变量VOXCPM_API_URL: {os.environ.get('VOXCPM_API_URL', '未设置')}")
43
  logger.info("="*50)
44
 
45
 
46
+ class VoxCPMClient:
47
+ """Client wrapper that talks to VoxCPM FastAPI server."""
48
 
49
  def __init__(self) -> None:
50
+ logger.info("📡 初始化VoxCPMClient...")
51
 
52
  try:
53
+ # VoxCPM API URL (can be overridden via env)
54
+ self.DEFAULT_API_URL = "https://deployment-5512-xjbzp8ey-7860.550w.link"
55
  self.api_url = self._resolve_server_url()
56
+ logger.info(f"🔗 准备连接到VoxCPM API: {self.api_url}")
57
 
58
  # Test connection
59
+ logger.info("⏳ 测试API连接...")
60
  health_start = time.time()
61
  health_response = requests.get(f"{self.api_url}/health", timeout=10)
62
  health_response.raise_for_status()
63
  health_time = time.time() - health_start
64
+ logger.info(f"✅ 成功连接到VoxCPM API: {self.api_url} (耗时: {health_time:.3f}秒)")
65
 
66
  except Exception as e:
67
+ logger.error(f"❌ 初始化VoxCPMClient失败: {e}")
68
  logger.error(f"错误详情: {traceback.format_exc()}")
69
  raise
70
 
71
  # ----------- Helpers -----------
72
  def _resolve_server_url(self) -> str:
73
+ """Resolve VoxCPM API base URL, prefer env VOXCPM_API_URL."""
74
+ return os.environ.get("VOXCPM_API_URL", self.DEFAULT_API_URL).rstrip("/")
75
 
76
  def _audio_file_to_base64(self, audio_file_path: str) -> str:
77
  """
 
136
 
137
  # ----------- Functional endpoints -----------
138
  def prompt_wav_recognition(self, prompt_wav: Optional[str]) -> str:
139
+ """Use VoxCPM ASR API for speech recognition."""
140
  logger.info(f"🎵 开始语音识别,输入文件: {prompt_wav}")
141
 
142
  if prompt_wav is None or not prompt_wav.strip():
 
151
  audio_base64 = self._audio_file_to_base64(prompt_wav)
152
  convert_time = time.time() - convert_start
153
 
154
+ # 构建ASR请求 - 匹配 advanced_api 格式
155
  asr_request = {
156
+ "wav_base64": audio_base64
 
 
 
157
  }
158
 
159
  # 调用ASR接口
 
174
  logger.info(f"⏱️ ASR总耗时: {total_time:.3f}秒")
175
  logger.info(f"🔍 完整的ASR响应: {result_data}")
176
 
177
+ # 检查响应状态 - advanced_api 格式返回 {"text": "识别文本"}
178
+ if isinstance(result_data, dict) and "text" in result_data:
 
 
179
  recognized_text = result_data.get("text", "")
180
  logger.info(f"🎯 识别结果: '{recognized_text}'")
181
  return recognized_text
182
  else:
183
  logger.warning(f"⚠️ ASR响应验证失败:")
184
  if isinstance(result_data, dict):
 
185
  logger.warning(f" - 是否有text字段: {'text' in result_data}")
 
186
  logger.warning(f"⚠️ 完整ASR响应: {result_data}")
187
  return ""
188
 
 
191
  logger.error(f"错误详情: {traceback.format_exc()}")
192
  return ""
193
 
194
+ def _call_api_generate(
195
  self,
196
  text: str,
197
  prompt_wav_path: Optional[str] = None,
 
202
  denoise: bool = True,
203
  ) -> Tuple[int, np.ndarray]:
204
  """
205
+ Call VoxCPM API and return (sample_rate, waveform).
206
+ 根据是否有 prompt audio 调用不同接口:
207
+ - 有 prompt: /generate_with_prompt(不注册,避免内存问题)
208
+ - 无 prompt: /generate_playground(使用默认音色)
209
  """
210
  try:
211
  start_time = time.time()
212
 
213
+ # 根据是否有参考音频选择不同的接口
214
+ if prompt_wav_path and os.path.exists(prompt_wav_path):
215
+ # ========== 有 prompt: 使用 /generate_with_prompt ==========
216
+ logger.info("🎭 使用语音克隆模式 - 调用 /generate_with_prompt")
217
+
218
+ # 转换音频为 base64
 
 
 
 
 
 
 
 
219
  convert_start = time.time()
220
  audio_base64 = self._audio_file_to_base64(prompt_wav_path)
221
  convert_time = time.time() - convert_start
222
+ logger.info(f"⏱️ 音频转换耗时: {convert_time:.3f}秒")
223
 
224
+ # 使用纯 JSON 请求(方式A),通过 wav_base64 传递音频
225
+ request_data = {
226
+ "target_text": text,
227
+ "wav_base64": audio_base64,
228
+ "prompt_text": prompt_text or "",
229
+ "denoise": denoise,
230
+ "register": False, # 不持久化,避免内存问题
231
+ "audio_format": "wav",
232
+ "max_generate_length": 2000,
233
+ "temperature": 1.0,
234
+ "cfg_value": cfg_value,
235
+ "stream": False
236
+ }
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
+ api_endpoint = f"{self.api_url}/generate_with_prompt"
 
 
 
 
239
 
240
+ api_start = time.time()
241
+ logger.info(f"📤 请求接口: {api_endpoint}")
242
+ response = requests.post(
243
+ api_endpoint,
244
+ json=request_data,
245
+ timeout=120
246
+ )
247
+ api_time = time.time() - api_start
248
+ logger.info(f"⏱️ API请求耗时: {api_time:.3f}秒")
249
+
250
+ # 打印详细错误信息
251
+ if response.status_code != 200:
252
+ logger.error(f"❌ API返回状态码: {response.status_code}")
253
+ logger.error(f"❌ API返回内容: {response.text}")
254
+ response.raise_for_status()
255
 
256
+ # /generate_with_prompt 返回 WAV 文件
257
+ content_type = response.headers.get("Content-Type", "")
258
+ if "audio/wav" in content_type:
259
+ logger.info("📥 收到 WAV 音频响应")
260
+ audio_bytes = response.content
261
+
262
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
263
+ tmp_file.write(audio_bytes)
264
+ tmp_file_path = tmp_file.name
265
+
266
+ try:
267
+ audio_data, sr = sf.read(tmp_file_path, dtype='float32')
268
+ if audio_data.ndim == 2:
269
+ audio_data = audio_data[:, 0]
270
+ audio_int16 = (audio_data * 32767).astype(np.int16)
271
+
272
+ total_time = time.time() - start_time
273
+ logger.info(f"📈 性能指标: API={api_time:.3f}s, 总计={total_time:.3f}s")
274
+
275
+ return sr, audio_int16
276
+ finally:
277
+ try:
278
+ os.unlink(tmp_file_path)
279
+ except:
280
+ pass
281
+ else:
282
+ # 可能返回 JSON 错误
283
+ result_data = response.json()
284
+ raise RuntimeError(f"API错误: {result_data}")
285
+
286
  else:
287
+ # ========== 无 prompt: 使用 /generate_playground ==========
288
+ logger.info("🎤 使用默认语音模式 - 调用 /generate_playground")
289
+ reqid = str(uuid.uuid4())
290
+
291
+ # 构建嵌套结构请求
292
+ request_data = {
293
+ "audio": {
294
+ "voice_type": "default", # 使用默认音色
295
+ "encoding": "wav",
296
+ "speed_ratio": 1.0,
297
+ "prompt_wav": None,
298
+ "prompt_wav_url": None,
299
+ "prompt_text": "",
300
+ "cfg_value": cfg_value,
301
+ "inference_timesteps": inference_timesteps
302
+ },
303
+ "request": {
304
+ "reqid": reqid,
305
+ "text": text,
306
+ "operation": "query",
307
+ "do_normalize": do_normalize,
308
+ "denoise": denoise
309
+ }
310
+ }
311
+
312
+ api_endpoint = f"{self.api_url}/generate_playground"
313
+
314
+ # 调用接口
315
+ api_start = time.time()
316
+ logger.info(f"📤 请求接口: {api_endpoint}")
317
+ response = requests.post(
318
+ api_endpoint,
319
+ json=request_data,
320
+ headers={"Content-Type": "application/json"},
321
+ timeout=120
322
+ )
323
+ response.raise_for_status()
324
+ api_time = time.time() - api_start
325
+ logger.info(f"⏱️ API请求耗时: {api_time:.3f}秒")
326
+
327
+ # /generate_playground 返回 JSON
328
+ result_data = response.json()
329
+ logger.info(f"📥 收到响应: code={result_data.get('code')}, message={result_data.get('message')}")
330
+
331
+ if isinstance(result_data, dict) and result_data.get("code") == 3000:
332
+ audio_base64 = result_data.get("data", "")
333
+ if audio_base64:
334
+ decode_start = time.time()
335
+ sample_rate, audio_array = self._base64_to_audio_array(audio_base64)
336
+ decode_time = time.time() - decode_start
337
+ total_time = time.time() - start_time
338
+
339
+ duration = result_data.get("addition", {}).get("duration", "0")
340
+ logger.info(f"📈 性能指标: API={api_time:.3f}s, 解码={decode_time:.3f}s, 总计={total_time:.3f}s, 音频时长={duration}ms")
341
+
342
+ return sample_rate, audio_array
343
+ else:
344
+ raise RuntimeError(f"API返回空音频数据。响应: {result_data}")
345
+ else:
346
+ error_code = result_data.get("code", "unknown")
347
+ error_msg = result_data.get("message", "unknown error")
348
+ logger.error(f"❌ API返回错误: code={error_code}, message={error_msg}")
349
+ raise RuntimeError(f"API错误 [{error_code}]: {error_msg}")
350
 
351
  except requests.exceptions.RequestException as e:
352
+ logger.error(f"❌ API请求失败: {e}")
353
+ raise RuntimeError(f"Failed to connect TTS service: {e}. Check VOXCPM_API_URL='{self.api_url}' and service status")
354
  except Exception as e:
355
+ logger.error(f"❌ API调用异常: {e}")
356
  raise
357
 
358
  def generate_tts_audio(
 
383
  cfg_value = cfg_value_input if cfg_value_input is not None else 2.0
384
  inference_timesteps = inference_timesteps_input if inference_timesteps_input is not None else 10
385
 
386
+ sr, wav_np = self._call_api_generate(
387
  text=text,
388
  prompt_wav_path=prompt_wav_path,
389
  prompt_text=prompt_text,
 
403
 
404
  # ---------- UI Builders ----------
405
 
406
+ def create_demo_interface(client: VoxCPMClient):
407
  """Build the Gradio UI for Gradio API VoxCPM client."""
408
  logger.info("🎨 开始创建Gradio界面...")
409
 
 
444
  """
445
  ) as interface:
446
  gr.HTML('<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm-logo.png" alt="VoxCPM Logo"></div>')
447
+
448
+ # Update notice
449
+ gr.Markdown("📢 **12/05: We upgraded the inference model to VoxCPM-1.5.**")
450
 
451
  # Quick Start
452
  with gr.Accordion("📋 Quick Start Guide |快速入门", open=False, elem_id="acc_quick"):
 
567
 
568
  try:
569
  # 创建客户端
570
+ logger.info("📡 创建VoxCPM API客户端...")
571
+ client = VoxCPMClient()
572
+ logger.info("✅ VoxCPM API客户端创建成功")
573
 
574
  # 创建界面
575
  logger.info("🎨 创建Gradio界面...")