| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """ |
| | Processor class for SparkTTS. Combines text tokenization and audio feature extraction/processing. |
| | """ |
| |
|
| | import os |
| | import re |
| | import torch |
| | import numpy as np |
| | import soundfile as sf |
| | import soxr |
| |
|
| | from pathlib import Path |
| | from typing import Optional, Union, List, Dict, Tuple, Any |
| |
|
| | from transformers.processing_utils import ProcessorMixin |
| | from transformers.tokenization_utils_base import BatchEncoding |
| | from transformers.feature_extraction_utils import BatchFeature |
| | from transformers.models.auto.tokenization_auto import AutoTokenizer |
| | from transformers.models.wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor |
| | from transformers.utils import logging, PushToHubMixin |
| | from numpy.lib.stride_tricks import sliding_window_view |
| | import soxr |
| | import soundfile |
| | import random |
| |
|
| | |
| | from .configuration_spark_tts import SparkTTSConfig |
| |
|
| | logger = logging.get_logger(__name__) |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| |
|
def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray:
    """
    Normalize the volume of an audio signal.

    Parameters:
        audio (numpy array): Input audio signal array.
        coeff (float): Target coefficient for normalization, default is 0.2.

    Returns:
        numpy array: The volume-normalized audio signal.
    """
    # Guard: an empty signal would make temp[-1] below raise IndexError.
    if audio.size == 0:
        return audio

    # Sorted absolute amplitudes; temp[-1] is the signal peak.
    temp = np.sort(np.abs(audio))

    # Very quiet signal: boost so the peak reaches ~0.1. The 1e-3 floor on
    # the divisor prevents blowing up near-silent signals.
    if temp[-1] < 0.1:
        scaling_factor = max(
            temp[-1], 1e-3
        )
        audio = audio / scaling_factor * 0.1

    # Keep only samples above the noise floor (thresholds computed on the
    # pre-boost amplitudes, matching the original behavior).
    temp = temp[temp > 0.01]
    L = temp.shape[0]

    # Too few significant samples to estimate loudness reliably.
    if L <= 10:
        return audio

    # Estimate loudness from the 90th-99th percentile of significant samples.
    volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)])

    # Scale toward the target coefficient, limiting the gain to [0.1, 10].
    audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10)

    # Final safeguard against clipping.
    max_value = np.max(np.abs(audio))
    if max_value > 1:
        audio = audio / max_value

    return audio
| |
|
| |
|
def load_audio(
    adfile: Path,
    sampling_rate: Optional[int] = None,
    length: Optional[int] = None,
    volume_normalize: bool = False,
    segment_duration: Optional[int] = None,
) -> np.ndarray:
    r"""Load an audio file with target sampling rate and length.

    Args:
        adfile (Path): Path to the audio file.
        sampling_rate (int, optional): Target sampling rate. Defaults to None.
        length (int, optional): Target audio length in samples. Defaults to None.
        volume_normalize (bool, optional): Whether to perform volume normalization. Defaults to False.
        segment_duration (int, optional): Randomly select a segment with a duration of
            {segment_duration} seconds. Default to None, which means the whole audio
            will be used.

    Returns:
        audio (np.ndarray): Mono audio samples.
    """
    audio, sr = soundfile.read(adfile)
    # Multi-channel input: keep only the first channel.
    if len(audio.shape) > 1:
        audio = audio[:, 0]

    # Resample when the file's native rate differs from the requested one.
    if sampling_rate is not None and sr != sampling_rate:
        audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ")
        sr = sampling_rate

    # Optionally crop a random fixed-duration segment (after resampling,
    # so seg_length is in target-rate samples).
    if segment_duration is not None:
        seg_length = int(sr * segment_duration)
        audio = random_select_audio_segment(audio, seg_length)

    if volume_normalize:
        audio = audio_volume_normalize(audio)

    if length is not None:
        # Only small mismatches (< 1000 samples) are tolerated; then trim or zero-pad.
        assert abs(audio.shape[0] - length) < 1000
        if audio.shape[0] > length:
            audio = audio[:length]
        else:
            audio = np.pad(audio, (0, int(length - audio.shape[0])))
    return audio
| |
|
| |
|
def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray:
    """Return a randomly positioned segment of ``audio`` with the given length.

    Args:
        audio (np.ndarray): Input audio samples.
        length (int): Segment length in samples (sampling_rate * duration).
    """
    n_samples = audio.shape[0]
    # Zero-pad on the right when the input is shorter than the requested segment.
    if n_samples < length:
        audio = np.pad(audio, (0, int(length - n_samples)))
        n_samples = audio.shape[0]

    start = random.randint(0, n_samples - length)
    stop = int(start + length)
    return audio[start:stop]
| |
|
def get_ref_clip(wav: np.ndarray, config) -> np.ndarray:
    """Get a fixed-length reference audio clip for speaker embedding.

    The clip length is ``sample_rate * ref_segment_duration`` rounded down to a
    multiple of ``latent_hop_length``; inputs shorter than that are tiled.
    """
    required = ('sample_rate', 'ref_segment_duration', 'latent_hop_length')
    if any(not hasattr(config, attr) for attr in required):
        raise AttributeError("Config object missing required attributes for get_ref_clip")

    hop = config.latent_hop_length
    ref_segment_length = (
        int(config.sample_rate * config.ref_segment_duration) // hop * hop
    )
    wav_length = len(wav)

    # Tile short audio until it covers the reference length, then truncate.
    if ref_segment_length > wav_length:
        wav = np.tile(wav, ref_segment_length // wav_length + 1)
    return wav[:ref_segment_length]
| |
|
| |
|
| | |
| |
|
# Maps task names to the special task tokens that prefix the LLM prompt.
TASK_TOKEN_MAP = {
    "vc": "<|task_vc|>",
    "tts": "<|task_tts|>",
    "asr": "<|task_asr|>",
    "s2s": "<|task_s2s|>",
    "t2s": "<|task_t2s|>",
    "understand": "<|task_understand|>",
    "caption": "<|task_cap|>",
    "controllable_tts": "<|task_controllable_tts|>",
    "prompt_tts": "<|task_prompt_tts|>",
    "speech_edit": "<|task_edit|>",
}

# Maps textual level labels (pitch/speed/loudness) to integer level IDs
# embedded in the corresponding special tokens.
LEVELS_MAP = {
    "very_low": 0,
    "low": 1,
    "moderate": 2,
    "high": 3,
    "very_high": 4,
}

# Inverse-style mapping used by UI sliders (1-based) back to level labels.
LEVELS_MAP_UI = {
    1: 'very_low',
    2: 'low',
    3: 'moderate',
    4: 'high',
    5: 'very_high'
}

# Maps gender labels to integer IDs used in <|gender_N|> tokens.
GENDER_MAP = {
    "female": 0,
    "male": 1,
}

# Maps age-group labels to integer IDs used in <|age_N|> tokens.
AGE_MAP = {"Child": 0, "Teenager": 1, "Youth-Adult": 2, "Middle-aged": 3, "Elderly": 4}

# Maps emotion labels to integer IDs used in <|emotion_N|> tokens.
EMO_MAP = {
    "UNKNOWN": 0,
    "NEUTRAL": 1,
    "ANGRY": 2,
    "HAPPY": 3,
    "SAD": 4,
    "FEARFUL": 5,
    "DISGUSTED": 6,
    "SURPRISED": 7,
    "SARCASTIC": 8,
    "EXCITED": 9,
    "SLEEPY": 10,
    "CONFUSED": 11,
    "EMPHASIS": 12,
    "LAUGHING": 13,
    "SINGING": 14,
    "WORRIED": 15,
    "WHISPER": 16,
    "ANXIOUS": 17,
    "NO-AGREEMENT": 18,
    "APOLOGETIC": 19,
    "CONCERNED": 20,
    "ENUNCIATED": 21,
    "ASSERTIVE": 22,
    "ENCOURAGING": 23,
    "CONTEMPT": 24,
}
| |
|
| |
|
class TokenParser:
    """Turn attribute labels (age, gender, pitch, loudness, speed, emotion,
    task) into the corresponding special tokens used in LLM prompts.

    All methods are stateless static helpers; the class only serves as a
    namespace. (A previously duplicated no-op ``__init__`` was removed.)
    """

    @staticmethod
    def age(age: str) -> str:
        """Turn age token."""
        age_id = AGE_MAP[age]
        return f"<|age_{age_id}|>"

    @staticmethod
    def gender(gender: str) -> str:
        """Turn gender token."""
        gender_id = GENDER_MAP[gender]
        return f"<|gender_{gender_id}|>"

    @staticmethod
    def mel_value(mel: int) -> str:
        """Turn special token of mel scale pitch, clamped to [0, 1000]."""
        mel = max(0, int(mel))
        mel = min(1000, int(mel))
        return f"<|pitch_value_{mel}|>"

    @staticmethod
    def mel_level(level: str) -> str:
        """Turn special token of mel level."""
        level_tag = LEVELS_MAP[level]
        return f"<|pitch_label_{level_tag}|>"

    @staticmethod
    def pitch_var_value(pitch_std: int) -> str:
        """Turn special token of pitch_std value, clamped to [0, 10]."""
        assert isinstance(pitch_std, int)
        pitch_std = max(0, int(pitch_std))
        pitch_std = min(10, int(pitch_std))
        return f"<|pitch_var_value_{pitch_std}|>"

    @staticmethod
    def pitch_var_level(level: str) -> str:
        """Turn special token of pitch std level."""
        level_tag = LEVELS_MAP[level]
        return f"<|pitch_var_label_{level_tag}|>"

    @staticmethod
    def loudness_value(loudness: int) -> str:
        """Turn special token of loudness value, clamped to [0, 30]."""
        assert loudness >= 0
        loudness = max(0, int(loudness))
        loudness = min(30, int(loudness))
        return f"<|loudness_value_{loudness}|>"

    @staticmethod
    def loudness_level(level: str) -> str:
        """Turn special token of loudness level."""
        level_tag = LEVELS_MAP[level]
        return f"<|loudness_label_{level_tag}|>"

    @staticmethod
    def speed_value(speed: int) -> str:
        """Turn special token of speed value, clamped to [0, 10]."""
        speed = max(0, int(speed))
        speed = min(10, int(speed))
        return f"<|speed_value_{speed}|>"

    @staticmethod
    def speed_level(level: str) -> str:
        """Turn special token of speed level."""
        level_tag = LEVELS_MAP[level]
        return f"<|speed_label_{level_tag}|>"

    @staticmethod
    def task(task: str) -> str:
        """Turn special token of task."""
        assert task in TASK_TOKEN_MAP.keys()

        return TASK_TOKEN_MAP[task]

    @staticmethod
    def emotion(emotion: str) -> str:
        """Turn special token of emotion."""
        emo_id = EMO_MAP[emotion]

        return f"<|emotion_{emo_id}|>"
| |
|
| | |
| | |
| | |
| |
|
| |
|
class SparkTTSProcessor(ProcessorMixin, PushToHubMixin):
    r"""
    Constructs a SparkTTS processor which wraps a text tokenizer and relevant audio processing logic.

    Args:
        tokenizer ([`PreTrainedTokenizer`]):
            An instance of [`PreTrainedTokenizer`]. This handles the text tokenization for the LLM.
        feature_extractor ([`Wav2Vec2FeatureExtractor`]):
            An instance of [`Wav2Vec2FeatureExtractor`]. Although Wav2Vec2 features are extracted
            within the model's `tokenize_audio`, the extractor's configuration (like sampling rate)
            is useful, and it aligns with the ProcessorMixin pattern.
        config ([`SparkTTSConfig`], *optional*):
            An instance of [`SparkTTSConfig`] to access configuration parameters like sample rate.
    """
    attributes = ["tokenizer", "feature_extractor"]
    tokenizer_class = "AutoTokenizer"
    feature_extractor_class = "Wav2Vec2FeatureExtractor"

    def __init__(self, tokenizer, feature_extractor, config: Optional[SparkTTSConfig] = None, **kwargs):
        super().__init__(tokenizer=tokenizer, feature_extractor=feature_extractor, **kwargs)
        # The SparkTTSModel this processor delegates audio (de)tokenization to;
        # must be set via `link_model()` before cloning or decoding.
        self.model = None
        self.config = config

        # Sampling-rate resolution priority: explicit config > feature extractor > 16 kHz default.
        if config and hasattr(config, 'sample_rate'):
            self.sampling_rate = config.sample_rate
        elif feature_extractor and hasattr(feature_extractor, 'sampling_rate'):
            self.sampling_rate = feature_extractor.sampling_rate
        else:
            self.sampling_rate = 16000
            logger.warning(f"Could not determine sampling rate. Defaulting to {self.sampling_rate} Hz.")

    def link_model(self, model):
        """Links the processor to a SparkTTSModel instance for audio processing calls."""
        # Duck-typed check: the model must expose both audio codec entry points.
        if not hasattr(model, 'tokenize_audio') or not hasattr(model, 'detokenize_audio'):
            raise TypeError("The provided model instance does not have the required 'tokenize_audio' and 'detokenize_audio' methods.")
        if not hasattr(model, 'config'):
            logger.warning("Linked model does not have a 'config' attribute. Some processor functionalities might rely on it.")

        self.model = model
        logger.info("SparkTTSModel successfully linked to the processor.")

        # Keep processor and feature-extractor sampling rates in sync with the
        # linked model's config, which is treated as authoritative.
        if hasattr(model, 'config') and hasattr(model.config, 'sample_rate'):
            if self.sampling_rate != model.config.sample_rate:
                logger.info(f"Updating processor sampling rate from {self.sampling_rate} to {model.config.sample_rate} based on linked model config.")
                self.sampling_rate = model.config.sample_rate

            if hasattr(self, 'feature_extractor') and self.feature_extractor.sampling_rate != model.config.sample_rate:
                logger.info(f"Updating feature_extractor sampling rate from {self.feature_extractor.sampling_rate} to {model.config.sample_rate}.")
                self.feature_extractor.sampling_rate = model.config.sample_rate

    def __call__(
        self,
        text: str,
        prompt_speech_path: Optional[Union[str, Path]] = None,
        prompt_text: Optional[str] = None,
        gender: Optional[str] = None,
        pitch: Optional[str] = None,
        speed: Optional[str] = None,
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ) -> BatchEncoding:
        """
        Processes the input text and optional prompt audio/control parameters into a format suitable for [`SparkTTSModel`].

        Args:
            text (`str`):
                The main text to be synthesized.
            prompt_speech_path (`str` or `Path`, *optional*):
                Path to the prompt audio file for voice cloning. Required if `gender` is not set.
            prompt_text (`str`, *optional*):
                Transcript of the prompt audio. Used only in voice cloning mode.
            gender (`str`, *optional*):
                Target gender ("male" or "female") for controllable synthesis. If set, enables control mode.
            pitch (`str`, *optional*):
                Target pitch level ("very_low", "low", "moderate", "high", "very_high") for control mode. Required if `gender` is set.
            speed (`str`, *optional*):
                Target speed level ("very_low", "low", "moderate", "high", "very_high") for control mode. Required if `gender` is set.
            return_tensors (`str`, *optional*, defaults to `"pt"`):
                If set, will return tensors instead of list of python integers. Only "pt" (PyTorch) is supported currently.
            **kwargs:
                Additional arguments passed to the underlying tokenizer's `__call__` method.

        Returns:
            [`BatchEncoding`]: A dictionary containing the `input_ids` and `attention_mask` for the LLM.
                In voice cloning mode, it also includes `global_token_ids_prompt` (torch.Tensor) representing the
                global tokens extracted from the prompt audio.
        """
        global_token_ids_prompt = None

        # Mode selection: `gender` takes precedence over a prompt-audio path.
        is_control_mode = gender is not None
        is_cloning_mode = prompt_speech_path is not None and not is_control_mode

        if is_control_mode:
            # --- Controllable TTS: build prompt from attribute tokens ---
            if not all([pitch, speed]):
                raise ValueError("For controllable TTS, 'gender', 'pitch', and 'speed' must all be provided.")
            if prompt_speech_path is not None:
                logger.warning("`prompt_speech_path` provided but ignored because `gender` is set (controllable TTS mode).")

            if not all(k in GENDER_MAP for k in [gender]):
                raise ValueError(f"Invalid gender provided: {gender}. Must be one of {list(GENDER_MAP.keys())}")
            if not all(k in LEVELS_MAP for k in [pitch, speed]):
                raise ValueError(f"Invalid pitch or speed level provided. Must be one of {list(LEVELS_MAP.keys())}")

            gender_id = GENDER_MAP[gender]
            pitch_level_id = LEVELS_MAP[pitch]
            speed_level_id = LEVELS_MAP[speed]

            pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>"
            speed_label_tokens = f"<|speed_label_{speed_level_id}|>"
            gender_tokens = f"<|gender_{gender_id}|>"

            attribute_tokens = "".join([gender_tokens, pitch_label_tokens, speed_label_tokens])

            prompt_list = [
                TASK_TOKEN_MAP["controllable_tts"],
                "<|start_content|>",
                text,
                "<|end_content|>",
                "<|start_style_label|>",
                attribute_tokens,
                "<|end_style_label|>",
            ]
            prompt_string = "".join(prompt_list)

        elif is_cloning_mode:
            # --- Voice cloning: tokenize prompt audio into global/semantic tokens ---
            if self.model is None:
                raise RuntimeError("Processor must be linked to a SparkTTSModel instance via `processor.link_model(model)` before performing voice cloning.")
            prompt_speech_path = Path(prompt_speech_path)
            if not prompt_speech_path.exists():
                raise FileNotFoundError(f"Prompt audio file not found: {prompt_speech_path}")

            try:
                # Prefer the linked model's config; fall back to the processor's own.
                model_config = self.model.config if self.model and hasattr(self.model, 'config') else self.config
                if model_config is None:
                    raise ValueError("Configuration not available in processor or linked model.")

                wav = load_audio(
                    prompt_speech_path,
                    sampling_rate=self.sampling_rate,
                    volume_normalize=getattr(model_config, 'volume_normalize', True),
                )
                # Reference clip for speaker embedding plus the full waveform,
                # both shaped [1, T] float tensors for the codec.
                wav_ref_np = get_ref_clip(wav, model_config)
                wav_ref = torch.from_numpy(wav_ref_np).unsqueeze(0).float()
                wav_tensor = torch.from_numpy(wav).unsqueeze(0).float()

                global_tokens_tensor, semantic_tokens_tensor = self.model.tokenize_audio(wav_tensor, wav_ref)

                # Kept as a tensor so decode() can reuse the prompt's global tokens.
                global_token_ids_prompt = global_tokens_tensor

                global_token_list = global_tokens_tensor.squeeze().tolist()
                semantic_token_list = semantic_tokens_tensor.squeeze().tolist()

            except Exception as e:
                logger.error(f"Error processing prompt audio {prompt_speech_path}: {e}")
                import traceback
                traceback.print_exc()
                raise

            global_tokens_str = "".join([f"<|bicodec_global_{gid}|>" for gid in global_token_list])
            semantic_tokens_str = "".join([f"<|bicodec_semantic_{sid}|>" for sid in semantic_token_list])

            # With a transcript, the prompt's semantic tokens are included so the
            # model continues them; without one, only global (speaker) tokens go in.
            if prompt_text is not None and prompt_text.strip():
                logger.info("Using prompt text in voice cloning prompt.")
                prompt_list = [
                    TASK_TOKEN_MAP["tts"],
                    "<|start_content|>",
                    prompt_text,
                    text,
                    "<|end_content|>",
                    "<|start_global_token|>",
                    global_tokens_str,
                    "<|end_global_token|>",
                    "<|start_semantic_token|>",
                    semantic_tokens_str,
                ]
            else:
                logger.info("No prompt text provided, using text-only voice cloning prompt.")
                prompt_list = [
                    TASK_TOKEN_MAP["tts"],
                    "<|start_content|>",
                    text,
                    "<|end_content|>",
                    "<|start_global_token|>",
                    global_tokens_str,
                    "<|end_global_token|>",
                ]
            prompt_string = "".join(prompt_list)
            logger.debug(f"Generated prompt string (cloning): {prompt_string[:200]}...")

        else:
            raise ValueError("Invalid input combination. Either provide `prompt_speech_path` for cloning or (`gender`, `pitch`, `speed`) for control.")

        # Tokenize the assembled prompt; known tokenizer kwargs are pulled out
        # explicitly, the rest are forwarded as-is.
        inputs = self.tokenizer(
            prompt_string,
            return_tensors=return_tensors,
            padding=kwargs.get("padding", False),
            truncation=kwargs.get("truncation", True),
            max_length=kwargs.get("max_length", self.tokenizer.model_max_length),
            add_special_tokens=kwargs.get("add_special_tokens", True),
            return_attention_mask=kwargs.get("return_attention_mask", True),
            **{k: v for k, v in kwargs.items() if k not in ["padding", "truncation", "max_length", "add_special_tokens", "return_attention_mask"]}
        )
        logger.debug(f"Tokenized input_ids shape: {inputs['input_ids'].shape}")

        # Attach the prompt's global tokens so decode() can reuse them.
        if is_cloning_mode and global_token_ids_prompt is not None:
            if return_tensors == "pt":
                inputs["global_token_ids_prompt"] = global_token_ids_prompt
            else:
                inputs["global_token_ids_prompt"] = global_token_ids_prompt.tolist()

        return inputs

    def decode(
        self,
        generated_ids: torch.Tensor,
        global_token_ids_prompt: Optional[torch.Tensor] = None,
        input_ids_len: Optional[int] = None,
        skip_special_tokens: bool = True,
    ) -> Dict[str, Any]:
        """
        Decodes the generated token IDs from [`SparkTTSModel`] into an audio waveform.

        Args:
            generated_ids (`torch.Tensor`):
                Tensor of token IDs generated by `model.generate()`, including the input prompt part. Shape [B, seq_len].
            global_token_ids_prompt (`torch.Tensor`, *optional*):
                The global tokens extracted from the prompt audio during the `__call__` step (for voice cloning).
                Shape [B, N_global]. Required if the generation was for voice cloning.
            input_ids_len (`int`, *optional*):
                The length of the original input prompt `input_ids` fed to `model.generate()`. Required to
                correctly isolate the newly generated tokens.
            skip_special_tokens (`bool`, *optional*, defaults to `True`):
                Whether to skip special tokens during the text decoding step (used to extract audio tokens).

        Returns:
            Dict[str, Any]: A dictionary containing:
                - "audio": The decoded audio waveform as a NumPy array. Shape [T_audio] (if B=1) or [B, T_audio].
                - "sampling_rate": The sampling rate of the audio.
        """
        if self.model is None:
            raise RuntimeError("Processor must be linked to a SparkTTSModel instance via `processor.link_model(model)` before decoding.")
        if input_ids_len is None:
            raise ValueError("`input_ids_len` (length of the prompt input_ids) must be provided for decoding.")

        # Strip the prompt portion, keeping only newly generated tokens.
        # NOTE(review): both branches slice identically; the condition only
        # controls the warning for suspiciously short sequences.
        if generated_ids.shape[1] < input_ids_len:
            logger.warning(f"Generated sequence length ({generated_ids.shape[1]}) is shorter than input prompt length ({input_ids_len}). Decoding might be incorrect.")
            output_only_ids = generated_ids[:, input_ids_len:]
        else:
            output_only_ids = generated_ids[:, input_ids_len:]

        # Decode to text so the bicodec tokens can be recovered via regex.
        decoded_texts = self.tokenizer.batch_decode(output_only_ids, skip_special_tokens=skip_special_tokens)

        # Per-item extraction; items that fail are skipped, not fatal.
        batch_size = generated_ids.shape[0]
        all_semantic_ids = []
        all_global_tokens = []
        successful_indices = []

        for i in range(batch_size):
            decoded_text = decoded_texts[i]
            current_semantic_ids = None
            current_global_tokens = None

            # Semantic tokens must come from the generated text in both modes.
            try:
                pred_semantic_indices = [int(token) for token in re.findall(r"bicodec_semantic_(\d+)", decoded_text)]
                if not pred_semantic_indices:
                    logger.warning(f"Batch item {i}: No semantic tokens found in decoded text: '{decoded_text[:200]}...'")
                    continue

                current_semantic_ids = torch.tensor(pred_semantic_indices).long()
            except Exception as e:
                logger.error(f"Batch item {i}: Error parsing semantic tokens from: '{decoded_text[:200]}...'. Error: {e}")
                continue

            # Global tokens: reuse the prompt's (cloning) or parse from the
            # generated text (controllable mode).
            if global_token_ids_prompt is not None:
                if global_token_ids_prompt.shape[0] != batch_size:
                    raise ValueError(f"Batch size mismatch: generated_ids has {batch_size}, but global_token_ids_prompt has {global_token_ids_prompt.shape[0]}.")
                current_global_tokens = global_token_ids_prompt[i]
            else:
                try:
                    pred_global_indices = [int(token) for token in re.findall(r"bicodec_global_(\d+)", decoded_text)]
                    if not pred_global_indices:
                        logger.warning(f"Batch item {i}: No global tokens found in decoded text for control mode: '{decoded_text[:200]}...'")
                        continue

                    current_global_tokens = torch.tensor(pred_global_indices).long()

                except Exception as e:
                    logger.error(f"Batch item {i}: Error parsing global tokens from: '{decoded_text[:200]}...'. Error: {e}")
                    continue

            all_semantic_ids.append(current_semantic_ids)
            all_global_tokens.append(current_global_tokens)
            successful_indices.append(i)

        if not successful_indices:
            logger.error("Failed to extract audio tokens for any item in the batch.")
            return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}

        if batch_size > 1 and len(successful_indices) < batch_size:
            logger.warning(f"Only successfully decoded {len(successful_indices)} out of {batch_size} batch items.")

        # Only single-item decoding is currently supported; B > 1 is rejected
        # until BiCodec's batch handling/padding behavior is verified.
        try:
            if len(successful_indices) != 1:
                raise NotImplementedError("Batch decoding (B > 1) requires verification of BiCodec's batch handling and potentially padding.")

            final_semantic_ids = all_semantic_ids[0].unsqueeze(0)
            final_global_tokens = all_global_tokens[0].unsqueeze(0)

        except IndexError:
            logger.error("Internal error during token batch preparation.")
            return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}

        # Vocode the token sequences back into a waveform.
        try:
            output_wav = self.model.detokenize_audio(final_global_tokens, final_semantic_ids)

        except Exception as e:
            logger.error(f"Error during audio detokenization: {e}")
            import traceback
            traceback.print_exc()
            raise RuntimeError("Audio detokenization failed.") from e

        return {"audio": output_wav, "sampling_rate": self.sampling_rate}

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        trust_remote_code: bool = False,
        **kwargs,
    ):
        r"""
        Instantiate a SparkTTSProcessor from pretrained components.

        Loads the LLM tokenizer and the Wav2Vec2 feature extractor, first from
        subfolders of `pretrained_model_name_or_path` (paths taken from the
        SparkTTSConfig when available), then by treating the configured paths
        as standalone model identifiers as a fallback.
        """
        config = kwargs.pop("config", None)

        # Resolve the SparkTTSConfig: use a passed-in instance, otherwise try
        # to load it from the checkpoint; failure falls back to default paths.
        loaded_config = None
        if not isinstance(config, SparkTTSConfig):
            try:
                loaded_config = SparkTTSConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    trust_remote_code=trust_remote_code,
                    **kwargs,
                )
            except Exception as e:
                logger.warning(
                    f"Could not load SparkTTSConfig from {pretrained_model_name_or_path}. "
                    f"Attempting to load components from default relative paths ('LLM', 'wav2vec2-large-xlsr-53'). Error: {e}"
                )
                loaded_config = None
        else:
            loaded_config = config

        # Default component subfolders, overridable via the config.
        llm_tokenizer_path_or_id = "./LLM"
        w2v_processor_path_or_id = "./wav2vec2-large-xlsr-53"

        if loaded_config:
            llm_tokenizer_path_or_id = getattr(loaded_config, 'llm_model_name_or_path', llm_tokenizer_path_or_id)
            w2v_processor_path_or_id = getattr(loaded_config, 'wav2vec2_model_name_or_path', w2v_processor_path_or_id)

        # Shared kwargs for both component loads.
        component_loading_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "local_files_only": local_files_only,
            "token": token,
            "revision": revision,
            **kwargs
        }
        # Tokenizer: try subfolder of the main checkpoint first, then the
        # configured path as a standalone identifier.
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                pretrained_model_name_or_path,
                subfolder=llm_tokenizer_path_or_id.lstrip('./'),
                trust_remote_code=trust_remote_code,
                **component_loading_kwargs
            )
        except Exception as e:
            if llm_tokenizer_path_or_id != "./LLM":
                try:
                    logger.info(f"Retrying tokenizer load directly from: {llm_tokenizer_path_or_id}")
                    tokenizer = AutoTokenizer.from_pretrained(
                        llm_tokenizer_path_or_id,
                        trust_remote_code=trust_remote_code,
                        **component_loading_kwargs
                    )
                except Exception as e2:
                    raise OSError(f"Could not load tokenizer using main path + subfolder or directly from '{llm_tokenizer_path_or_id}'. Error: {e2}") from e
            else:
                raise OSError(f"Could not load tokenizer from subfolder '{llm_tokenizer_path_or_id}' within '{pretrained_model_name_or_path}'. Error: {e}")

        # Feature extractor: same subfolder-then-direct strategy.
        try:
            feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
                pretrained_model_name_or_path,
                subfolder=w2v_processor_path_or_id.lstrip('./'),
                **component_loading_kwargs
            )
        except Exception as e:
            if w2v_processor_path_or_id != "./wav2vec2-large-xlsr-53":
                try:
                    logger.info(f"Retrying feature extractor load directly from: {w2v_processor_path_or_id}")
                    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
                        w2v_processor_path_or_id,
                        **component_loading_kwargs
                    )
                except Exception as e2:
                    raise OSError(f"Could not load feature extractor using main path + subfolder or directly from '{w2v_processor_path_or_id}'. Error: {e2}") from e
            else:
                raise OSError(f"Could not load feature extractor from subfolder '{w2v_processor_path_or_id}' within '{pretrained_model_name_or_path}'. Error: {e}")

        return cls(tokenizer=tokenizer, feature_extractor=feature_extractor, config=loaded_config)

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        push_to_hub: bool = False,
        **kwargs,
    ):
        """
        Save the processor's state (tokenizer and feature extractor files) to a directory.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the processor files will be saved.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it.
            **kwargs:
                Additional key word arguments passed along to the `push_to_hub` method.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        # Save both components into the same directory.
        self.tokenizer.save_pretrained(str(save_directory), **kwargs)

        self.feature_extractor.save_pretrained(str(save_directory), **kwargs)

        logger.info(f"Processor components saved in {save_directory}")

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", "Save processor")
            return self.push_to_hub(save_directory, commit_message=commit_message, **kwargs)

        return str(save_directory)