Update generation.py
Browse files- generation.py +3 -2
generation.py
CHANGED
|
@@ -1197,8 +1197,9 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
|
|
| 1197 |
self.vad_seek_callback(kwargs["stno_mask"])
|
| 1198 |
if "is_valid" in kwargs:
|
| 1199 |
kwargs['is_valid'] = kwargs["is_valid"][batch_idx_map]
|
| 1200 |
-
|
| 1201 |
-
|
|
|
|
| 1202 |
return kwargs
|
| 1203 |
|
| 1204 |
def generate_with_fallback(
|
|
|
|
| 1197 |
self.vad_seek_callback(kwargs["stno_mask"])
|
| 1198 |
if "is_valid" in kwargs:
|
| 1199 |
kwargs['is_valid'] = kwargs["is_valid"][batch_idx_map]
|
| 1200 |
+
if "labels" in kwargs:
|
| 1201 |
+
kwargs['labels'] = kwargs["labels"][batch_idx_map]
|
| 1202 |
+
kwargs['upp_labels'] = kwargs["upp_labels"][batch_idx_map]
|
| 1203 |
return kwargs
|
| 1204 |
|
| 1205 |
def generate_with_fallback(
|