| | from typing import Union |
| |
|
| | import numpy as np |
| | from transformers.utils import TensorType |
| | from transformers.feature_extraction_utils import BatchFeature |
| |
|
| |
|
class PadAndSortCollator:
    """Collate a list of processor outputs into one padded, length-sorted batch.

    Text tokens are padded with the processor's tokenizer and mel spectrograms
    with its feature extractor; a float32 ``gate_padded`` stop-target is derived
    from the audio attention mask. Finally every field is reordered so samples
    appear in descending text-length order (the layout expected by packed-sequence
    style TTS training loops — TODO confirm against the consuming trainer).
    """

    def __init__(self, processor, return_tensors: Union[str, TensorType] = "pt"):
        # `processor` is assumed to expose `.tokenizer` and `.feature_extractor`,
        # each with a HuggingFace-style `.pad(...)` method — verify at call site.
        self.processor = processor
        self.return_tensors = return_tensors

    def __call__(self, batch):
        """
        expect batch with `return_tensors=None` from processor
        batch: input_ids, length(optional), mel_specgram, mel_specgram_length(optional)
        """
        # --- gather text fields; derive lengths when the processor omitted them ---
        has_length = "length" in batch[0]
        text_features = {
            "input_ids": [item["input_ids"] for item in batch],
            "length": [
                item["length"] if has_length else len(item["input_ids"])
                for item in batch
            ],
        }

        # --- gather audio fields ---
        # item["mel_specgram"][0] is indexed then 2-D transposed below, so each
        # entry presumably holds a leading singleton batch axis with the frame
        # count on axis 1 — TODO confirm producer shape.
        mels = [item["mel_specgram"][0] for item in batch]
        # Transpose to (frames, n_mels) so the feature extractor pads along time.
        audio_features = {"mel_specgram": [mel.transpose(1, 0) for mel in mels]}
        if "mel_specgram_length" in batch[0]:
            audio_features["mel_specgram_length"] = [
                item["mel_specgram_length"] for item in batch
            ]
        else:
            audio_features["mel_specgram_length"] = [mel.shape[1] for mel in mels]

        # --- pad both modalities to rectangular numpy arrays ---
        padded_text = self.processor.tokenizer.pad(
            text_features,
            padding=True,
            return_tensors="np",
            return_attention_mask=False,
        )

        padded_audio = self.processor.feature_extractor.pad(
            audio_features,
            padding=True,
            return_tensors="np",
            return_attention_mask=True,
        )
        # Restore (batch, n_mels, frames) layout after time-axis padding.
        padded_audio["mel_specgram"] = padded_audio["mel_specgram"].transpose(0, 2, 1)

        # --- build the stop-token target from the audio mask ---
        # 1 on padded frames, shifted one step left so the final real frame is
        # flagged as "stop"; the last column is forced to 1 to terminate.
        mask = padded_audio.pop("attention_mask")
        gate_padded = np.roll(1 - mask, -1, axis=1)
        gate_padded[:, -1] = 1
        gate_padded = gate_padded.astype(np.float32)

        merged = {**padded_text, **padded_audio, "gate_padded": gate_padded}

        # --- sort every field by descending text length ---
        order = np.argsort(merged["length"])[::-1]
        merged = {key: value[order] for key, value in merged.items()}

        return BatchFeature(merged, tensor_type=self.return_tensors)
| |
|