import nltk
nltk.download('punkt_tab')  # <--- Add this line (needed by word_tokenize on newer nltk)
nltk.download('punkt')      # <--- Also add this line, to be safe

import os
import random
import yaml
import numpy as np
import torch
import torchaudio
import librosa
import soundfile as sf
from nltk.tokenize import word_tokenize

from models import *
from utils import *
from text_utils import TextCleaner

# Make runs reproducible.
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
random.seed(0)
np.random.seed(0)

textcleaner = TextCleaner()

to_mel = torchaudio.transforms.MelSpectrogram(
    n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
mean, std = -4, 4


def length_to_mask(lengths):
    # True marks the padded positions beyond each sequence's length.
    mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
    mask = torch.gt(mask + 1, lengths.unsqueeze(1))
    return mask


def preprocess(wave):
    # Waveform -> log-mel spectrogram, normalized with the training statistics.
    wave_tensor = torch.from_numpy(wave).float()
    mel_tensor = to_mel(wave_tensor)
    mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
    return mel_tensor


def compute_style(path):
    # librosa.load already resamples to 24 kHz, so no further resampling is needed.
    wave, sr = librosa.load(path, sr=24000)
    audio, index = librosa.effects.trim(wave, top_db=30)
    mel_tensor = preprocess(audio).to(device)
    with torch.no_grad():
        ref_s = model.style_encoder(mel_tensor.unsqueeze(1))      # acoustic style
        ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))  # prosodic style
    return torch.cat([ref_s, ref_p], dim=1)


device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the phonemizer (Vietnamese; the en-us variant is kept for reference).
import phonemizer
global_phonemizer = phonemizer.backend.EspeakBackend(
    language='vi', preserve_punctuation=True, with_stress=True,
    language_switch="remove-flags")
# global_phonemizer = phonemizer.backend.EspeakBackend(
#     language='en-us', preserve_punctuation=True, with_stress=True,
#     language_switch="remove-flags")

config = yaml.safe_load(open("./Configs/config_libritts.yml"))
# config = yaml.safe_load(open("./Configs/config_vokan.yml"))

# Load the pretrained ASR model (text aligner).
ASR_config = config.get('ASR_config', False)
ASR_path = config.get('ASR_path', False)
text_aligner = load_ASR_models(ASR_path, ASR_config)

# Load the pretrained F0 model (pitch extractor).
F0_path = config.get('F0_path', False)
pitch_extractor = load_F0_models(F0_path)

# Load the PL-BERT model.
from Utils.PLBERT.util import load_plbert
BERT_path = config.get('PLBERT_dir', False)
plbert = load_plbert(BERT_path)

model_params = recursive_munch(config['model_params'])
model = build_model(model_params, text_aligner, pitch_extractor, plbert)
_ = [model[key].eval() for key in model]
_ = [model[key].to(device) for key in model]

print("Loading pretrained checkpoint...")
# params_whole = torch.load("/u01/colombo/hungnt/hieuld/tts/StyleTTS2/hieuducle/model_40speaker/model_iter_00004000.pth", map_location='cpu')
# params_whole = torch.load("/u01/colombo/hungnt/hieuld/TTS_clone/pretrainedModel/hieuducle/train_second_bestcheckpoint/best_model.pth", map_location='cpu')
params_whole = torch.load("/workspace/StyleTTS2/Models/LibriTTS/model_iter_00002300.pth", map_location='cpu')
params = params_whole['net']

for key in model:
    if key in params:
        print('%s loaded' % key)
        try:
            model[key].load_state_dict(params[key])
        except Exception:
            # Checkpoints saved from DataParallel prefix keys with `module.`;
            # strip the prefix and retry.
            from collections import OrderedDict
            state_dict = params[key]
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:]  # remove `module.`
                new_state_dict[name] = v
            model[key].load_state_dict(new_state_dict, strict=False)
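# --- Optional sanity check (illustrative only, not part of the pipeline) ---
# compute_style should return a single (1, 256) style vector: the first 128
# dims from style_encoder (acoustic style), the last 128 from
# predictor_encoder (prosody), matching how LFinference splits it below.
# Uncomment to verify against the reference wav used later in this script:
# _ref = compute_style("/workspace/StyleTTS2/audio_ref/sangnq.wav")
# print(_ref.shape)  # expected: torch.Size([1, 256])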
_ = [model[key].eval() for key in model]  # re-assert eval mode after loading weights

from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule

sampler = DiffusionSampler(
    model.diffusion.diffusion,
    sampler=ADPM2Sampler(),
    sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),  # empirical parameters
    clamp=False
)

# Vietnamese input text to synthesize (kept in the source language on purpose).
passage = '''Lý do ông Putin chưa chấp nhận kế hoạch hòa bình từ Mỹ, Tổng thống Putin không vội chấp nhận đề xuất hòa bình Ukraine của Mỹ, khi Nga có lợi thế đàm phán nhờ đà tiến trên chiến trường, trong khi phương Tây lục đục nội bộ. Chính quyền Tổng thống Mỹ Donald Trump hồi tháng 11 xây dựng kế hoạch hòa bình 28 điểm để chấm dứt chiến sự Nga - Ukraine, sau đó tiến hành chiến dịch ngoại giao con thoi để thuyết phục hai bên chấp nhận. Ukraine cùng châu Âu đã nhanh chóng phản đối kế hoạch, do nó có nhiều đề xuất hoàn toàn có lợi cho Nga'''
# passage = '''Kumapaterado là ai và làm gì'''


def LFinference(text, s_prev, ref_s, alpha=0.3, beta=0.7, t=0.7, diffusion_steps=5, embedding_scale=1):
    text = text.strip()
    ps = global_phonemizer.phonemize([text])
    ps = word_tokenize(ps[0])
    ps = ' '.join(ps)
    ps = ps.replace('``', '"')
    ps = ps.replace("''", '"')
    # Swap plain 't' for aspirated 'tʰ' while protecting dental 't̪'
    # behind a sentinel character, so only plain 't' is rewritten.
    ps = ps.replace('t̪', '\uFFFF').replace('t', 'tʰ').replace('\uFFFF', 't')

    tokens = textcleaner(ps)
    tokens.insert(0, 0)
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

    with torch.no_grad():
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
        text_mask = length_to_mask(input_lengths).to(device)

        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

        s_pred = sampler(noise=torch.randn((1, 256)).unsqueeze(1).to(device),
                         embedding=bert_dur,
                         embedding_scale=embedding_scale,
                         features=ref_s,  # reference from the same speaker as the embedding
                         num_steps=diffusion_steps).squeeze(1)

        if s_prev is not None:
            # Convex combination of the previous and current style keeps
            # consecutive segments consistent.
            s_pred = t * s_prev + (1 - t) * s_pred

        s = s_pred[:, 128:]
        ref = s_pred[:, :128]

        ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
        s = beta * s + (1 - beta) * ref_s[:, 128:]

        s_pred = torch.cat([ref, s], dim=-1)

        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)

        x, _ = model.predictor.lstm(d)
        duration = model.predictor.duration_proj(x)

        duration = torch.sigmoid(duration).sum(axis=-1)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        # Build the hard alignment target from the predicted durations.
        pred_aln_trg = torch.zeros(int(input_lengths.item()), int(pred_dur.sum().item()))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].item())] = 1
            c_frame += int(pred_dur[i].item())

        # Encode prosody.
        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
        if model_params.decoder.type == "hifigan":
            asr_new = torch.zeros_like(en)
            asr_new[:, :, 0] = en[:, :, 0]
            asr_new[:, :, 1:] = en[:, :, 0:-1]
            en = asr_new

        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
        if model_params.decoder.type == "hifigan":
            asr_new = torch.zeros_like(asr)
            asr_new[:, :, 0] = asr[:, :, 0]
            asr_new[:, :, 1:] = asr[:, :, 0:-1]
            asr = asr_new

        out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

    # Trim the last 100 samples: the model emits a weird pulse at the end
    # of each segment, which still needs a proper fix.
    return out.squeeze().cpu().numpy()[..., :-100], s_pred


# Unseen speaker used as the voice reference.
path = "/workspace/StyleTTS2/audio_ref/sangnq.wav"
s_ref = compute_style(path)

sentences = passage.split('.')  # simple split on periods
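# Alternative (illustrative sketch): nltk's sent_tokenize handles abbreviations
# and mixed punctuation better than a naive split on '.'. It relies on the
# 'punkt'/'punkt_tab' resources downloaded above; note the default punkt model
# is English-trained, so treat this only as a rough improvement for Vietnamese.
# sent_tokenize keeps the trailing punctuation, so also drop the `text += '.'`
# line in the loop below if you switch.
# from nltk.tokenize import sent_tokenize
# sentences = sent_tokenize(passage)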
wavs = []
s_prev = None
print("Synthesizing...")
for text in sentences:
    if text.strip() == "":
        continue
    text += '.'  # add the period back

    wav, s_prev = LFinference(text, s_prev, s_ref,
                              alpha=0,  # 0.3 originally; 0 takes the acoustic style entirely from the reference
                              beta=0,   # 0.7 originally; 0 takes the prosody entirely from the reference
                              t=0.7,
                              diffusion_steps=5,
                              embedding_scale=1)
    wavs.append(wav)

# Concatenate all segments and save.
final_wav = np.concatenate(wavs)
name_audio = os.path.splitext(os.path.basename(path))[0]
os.makedirs("./audio_clone", exist_ok=True)  # ensure the output directory exists

# 1. Name the output file with an .mp3 extension.
out_path = f"./audio_clone/{name_audio}_clone.mp3"

# 2. Save as mp3 via torchaudio (instead of sf.write); this requires an
# mp3-capable backend (e.g. ffmpeg). final_wav is a numpy array, so convert it
# to a Tensor and add a channel dimension: (1, Time).
torchaudio.save(out_path, torch.from_numpy(final_wav).unsqueeze(0), 24000, format="mp3")
# WAV fallback:
# sf.write(f"./audio_clone/{name_audio}_clone.wav", final_wav, 24000, subtype="PCM_16")

print("Saved synthesized audio to:", out_path)
print("Reference audio:", path)
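# --- Optional (illustrative sketch): crossfaded concatenation ---
# Naive np.concatenate can leave audible clicks at segment boundaries
# (cf. the end-of-segment pulse trimmed inside LFinference). A short linear
# crossfade softens them. 'fade' is a hypothetical parameter in samples
# (240 samples ≈ 10 ms at 24 kHz); assumes every segment is longer than 'fade'.
def crossfade_concat(segments, fade=240):
    out = segments[0]
    ramp = np.linspace(0.0, 1.0, fade)  # 0 -> 1 over the overlap region
    for seg in segments[1:]:
        out = np.concatenate([
            out[:-fade],                                      # audio before the overlap
            out[-fade:] * (1.0 - ramp) + seg[:fade] * ramp,   # blended overlap
            seg[fade:],                                       # rest of the new segment
        ])
    return out
# Usage: final_wav = crossfade_concat(wavs) instead of np.concatenate(wavs).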