Translsis commited on
Commit
ac11f96
·
verified ·
1 Parent(s): 901a63a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -5
app.py CHANGED
@@ -146,13 +146,26 @@ def generate_speech(
146
  # Clean text
147
  full_script = text.strip().replace("'", "'").replace('"', '"').replace('"', '"')
148
 
149
- # Get voice sample
150
  voice_sample = VOICE_MAPPER.get_voice_path(speaker_name)
151
 
152
- # Load voice sample to GPU
153
  all_prefilled_outputs = torch.load(
154
- voice_sample, map_location="cuda", weights_only=False
155
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  progress(0.2, desc="Preparing inputs...")
158
 
@@ -165,8 +178,7 @@ def generate_speech(
165
  return_attention_mask=True,
166
  )
167
 
168
- # Move model and tensors to GPU
169
- MODEL.to("cuda")
170
  for k, v in inputs.items():
171
  if torch.is_tensor(v):
172
  inputs[k] = v.to("cuda")
 
146
  # Clean text
147
  full_script = text.strip().replace("'", "'").replace('"', '"').replace('"', '"')
148
 
149
+ # Get voice sample path
150
  voice_sample = VOICE_MAPPER.get_voice_path(speaker_name)
151
 
152
+ # Load voice sample to CPU first, then move to GPU
153
  all_prefilled_outputs = torch.load(
154
+ voice_sample, map_location="cpu", weights_only=False
155
  )
156
+
157
+ # Move model to GPU
158
+ MODEL.to("cuda")
159
+
160
+ # Move voice sample tensors to GPU
161
+ if isinstance(all_prefilled_outputs, dict):
162
+ for key in all_prefilled_outputs:
163
+ if torch.is_tensor(all_prefilled_outputs[key]):
164
+ all_prefilled_outputs[key] = all_prefilled_outputs[key].to("cuda")
165
+ elif isinstance(all_prefilled_outputs[key], dict):
166
+ for sub_key in all_prefilled_outputs[key]:
167
+ if torch.is_tensor(all_prefilled_outputs[key][sub_key]):
168
+ all_prefilled_outputs[key][sub_key] = all_prefilled_outputs[key][sub_key].to("cuda")
169
 
170
  progress(0.2, desc="Preparing inputs...")
171
 
 
178
  return_attention_mask=True,
179
  )
180
 
181
+ # Move input tensors to GPU
 
182
  for k, v in inputs.items():
183
  if torch.is_tensor(v):
184
  inputs[k] = v.to("cuda")