"""
Shared Preprocessing Module for Arabic Text Punctuation Prediction

This module contains all shared functionality for data preprocessing,
label extraction, and dataset creation, plus the BiLSTM tagger and the
``PunctuationRestorer`` inference wrapper.

Authors: Mohammed{Alkhuzanie, Fael}
Date: December 2025
"""

import gc
import glob
import os
import pickle
import random
import re
import zipfile
from collections import Counter, defaultdict
from typing import Dict, Iterator, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from torch.utils.data import DataLoader, Dataset

# Optional dependencies: the preprocessing utilities below work without them.
try:
    from tqdm import tqdm
except ImportError:  # progress bars are cosmetic; fall back to plain iteration
    def tqdm(iterable, *args, **kwargs):
        """Stand-in for ``tqdm.tqdm`` that simply returns ``iterable``."""
        return iterable

try:
    from huggingface_hub import hf_hub_download
except ImportError:  # only PunctuationRestorer needs the Hub client
    hf_hub_download = None


# ============================================================================
# CONSTANTS AND MAPPINGS
# ============================================================================

# Punctuation mark labels (class name -> class id)
PUNCTUATION_MAP = {
    'NO_PUNCT': 0,
    'QUESTION': 1,     # ?
    'COMMA': 2,        # ،
    'COLON': 3,        # :
    'SEMICOLON': 4,    # ؛
    'EXCLAMATION': 5,  # !
    'PERIOD': 6,       # .
}

# Reverse mapping used when writing text back out (class id -> mark).
# The question mark is emitted in its Arabic form (؟) even though training
# text is normalized to the ASCII '?' (see PUNCTUATION_CHARS).
LABEL_TO_PUNCT = {
    0: '',
    1: '؟',
    2: '،',
    3: ':',
    4: '؛',
    5: '!',
    6: '.',
}

# Normalized punctuation character -> class id. These are the characters
# left in the text after normalize_punctuation().
PUNCTUATION_CHARS = {
    '?': 1,  # Question mark (normalized from ؟)
    '،': 2,  # Arabic comma (normalized from ,)
    ':': 3,  # Colon
    '؛': 4,  # Arabic semicolon (normalized from ;)
    '!': 5,  # Exclamation mark
    '.': 6,  # Period (normalized from \u06d4)
}

# Random seed for reproducibility
SEED = 42

# Class id -> class name, used for reporting.
_LABEL_TO_NAME = {label: name for name, label in PUNCTUATION_MAP.items()}

# Precompiled patterns / tables shared by the preprocessing functions.
_DIACRITICS_RE = re.compile(r'[\u0617-\u061A\u064B-\u0652]')
_ALEF_VARIANTS_RE = re.compile('[إأآ]')
_PUNCT_TRANSLATION = str.maketrans({
    '؟': '?',       # Arabic question mark -> ASCII question mark
    ',': '،',       # ASCII comma -> Arabic comma
    ';': '؛',       # ASCII semicolon -> Arabic semicolon
    '\u06d4': '.',  # Arabic full stop -> period
})
# A target mark immediately followed by something other than space/newline/tab.
_PUNCT_FOLLOWED_BY_TEXT_RE = re.compile(r'([?،:؛!.])(?=[^ \n\t])')
# Whitespace that follows a period (sentence boundary used for chunking).
_SENTENCE_END_RE = re.compile(r'(?<=\.)\s+')
_BATCH_FILE_PATTERN = 'batch_*.pkl'


# ============================================================================
# DATA LOADING FUNCTIONS
# ============================================================================

def load_text_files(data_dir: str) -> List[str]:
    """
    Load all non-empty ``.txt`` files under ``data_dir`` (recursively).

    Directories and files are visited in sorted order so that the returned
    list -- and therefore every batch and train/val/test split derived from
    it -- is the same on every machine (``os.walk`` alone yields entries in
    filesystem order). Files that cannot be read or decoded as UTF-8 are
    reported and skipped.

    Args:
        data_dir: Path to data directory

    Returns:
        List of file contents, one string per non-empty file
    """
    texts = []
    for root, dirs, files in os.walk(data_dir):
        dirs.sort()  # in-place sort controls the order os.walk descends
        for file_name in sorted(files):
            if not file_name.endswith('.txt'):
                continue
            file_path = os.path.join(root, file_name)
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read()
            except (OSError, UnicodeDecodeError) as e:
                print(f"Error reading {file_path}: {e}")
                continue
            if content.strip():  # Only add non-empty content
                texts.append(content)
    return texts


# ============================================================================
# TEXT PREPROCESSING FUNCTIONS
# ============================================================================

def remove_diacritics(text: str) -> str:
    """
    Remove Arabic diacritics (tashkeel) from text.

    Strips the code points U+0617-U+061A and U+064B-U+0652 (the vowel marks,
    tanween, shadda and sukun). Removing them is common in NLP to reduce
    vocabulary size.

    Args:
        text: Input Arabic text

    Returns:
        Text without diacritics
    """
    return _DIACRITICS_RE.sub('', text)


def normalize_arabic_text(text: str) -> str:
    """
    Normalize Arabic characters to standard forms.

    Normalizations:
        - أ، إ، آ → ا (normalize alef variants)
        - ة → ه (taa marbuta to haa)

    Args:
        text: Input Arabic text

    Returns:
        Normalized text
    """
    text = _ALEF_VARIANTS_RE.sub('ا', text)
    return text.replace('ة', 'ه')


def normalize_punctuation(text: str) -> str:
    """
    Normalize punctuation marks to their canonical forms.

    This ensures consistency across different Unicode representations of the
    same punctuation mark.

    Normalizations:
        - ؟ (Arabic question mark U+061F) → ? (ASCII question mark)
        - , (ASCII comma) → ، (Arabic comma)
        - ; (ASCII semicolon) → ؛ (Arabic semicolon)
        - ۔ (Arabic full stop U+06D4) → . (period)
        - Colons and exclamation marks are kept as-is

    Args:
        text: Input text

    Returns:
        Text with normalized punctuation
    """
    # Single pass; no target character is itself a source, so this matches
    # applying the replacements one after another.
    return text.translate(_PUNCT_TRANSLATION)


def _space_after_punct(match) -> str:
    """Replacement for _PUNCT_FOLLOWED_BY_TEXT_RE: add a space unless the mark separates digits."""
    punct = match.group(1)
    pos = match.start(1)
    text = match.string
    # Keep "3.5" and "10:30" intact; splitting them would create false
    # PERIOD / COLON labels in the training data.
    if punct in '.:' and pos > 0 and text[pos - 1].isdigit() and text[pos + 1].isdigit():
        return punct
    return punct + ' '


def normalize_punctuation_spacing(text: str) -> str:
    """
    Add a space after each punctuation mark that is directly followed by text.

    This fixes formatting issues where punctuation is not followed by a space,
    e.g. "word1!word2" → "word1! word2", so that label extraction sees one
    mark per token. Consecutive marks are separated too ("a..." → "a. . .").
    A '.' or ':' between two digits ("3.5", "10:30") is left untouched.

    Args:
        text: Input text

    Returns:
        Text with proper spacing after punctuation
    """
    return _PUNCT_FOLLOWED_BY_TEXT_RE.sub(_space_after_punct, text)


def remove_tatweel(text: str) -> str:
    """
    Remove the Arabic tatweel character (ـ).

    Tatweel is used for text justification but has no linguistic value.

    Args:
        text: Input Arabic text

    Returns:
        Text without tatweel
    """
    return text.replace('ـ', '')


def remove_punctuation_marks(text: str) -> str:
    """
    Remove all target punctuation marks from text and collapse whitespace.

    Removes every key of PUNCTUATION_CHARS: question marks (?), colons (:),
    commas (،), semicolons (؛), exclamation marks (!) and full stops (.).
    Note that this also strips '.'/':' inside numbers such as "3.5".

    Args:
        text: Input text with punctuation marks

    Returns:
        Text without punctuation marks
    """
    for punct in PUNCTUATION_CHARS:
        text = text.replace(punct, '')
    # Remove extra whitespace that may result from punctuation removal
    return ' '.join(text.split())


def preprocess_text(text: str) -> str:
    """
    Apply all preprocessing steps to text.

    Steps, in order: remove diacritics, normalize letters, remove tatweel,
    normalize punctuation, fix spacing after punctuation, collapse whitespace.

    Args:
        text: Input text

    Returns:
        Preprocessed text
    """
    text = remove_diacritics(text)
    text = normalize_arabic_text(text)
    text = remove_tatweel(text)
    text = normalize_punctuation(text)
    text = normalize_punctuation_spacing(text)
    # Remove extra whitespace
    return ' '.join(text.split())


# ============================================================================
# LABEL EXTRACTION FUNCTION (CRITICAL)
# ============================================================================

def extract_words_and_labels(text: str, preprocess: bool = True) -> Tuple[List[str], List[int]]:
    """
    Extract words and their corresponding punctuation labels from text.

    This is the CORE function that converts punctuated text into training data.
    Each whitespace token contributes one word; if its last character is a
    mark from PUNCTUATION_CHARS, that mark becomes the word's label. A token
    consisting only of a mark (e.g. the "،" in "كلمة ، كلمة") is attached to
    the preceding word when that word has no label yet; otherwise it is dropped.

    Example:
        Input: "أكل الولد الخبز، وشرب الماء."
        Output (with preprocessing, which normalizes أ → ا):
            words = ['اكل', 'الولد', 'الخبز', 'وشرب', 'الماء']
            labels = [0, 0, 2, 0, 6]  # 0=none, 2=comma, 6=period

    Args:
        text: Raw text with punctuation
        preprocess: Whether to apply preprocess_text first. Without it, only
            the normalized marks in PUNCTUATION_CHARS are recognized.

    Returns:
        Tuple of (words, labels)
            - words: List of words without trailing punctuation
            - labels: List of punctuation labels (0-6), same length as words
    """
    if preprocess:
        text = preprocess_text(text)

    words: List[str] = []
    labels: List[int] = []
    for token in text.split():  # split() never yields empty tokens
        label = PUNCTUATION_CHARS.get(token[-1], 0)
        word = token[:-1] if label else token
        if word:
            words.append(word)
            labels.append(label)
        elif label and words and labels[-1] == 0:
            # Bare mark separated from its word by a space.
            labels[-1] = label
    return words, labels


def _process_text_batch(batch_texts: List[str],
                        min_words: int = 3,
                        max_words: int = 100) -> List[Tuple[List[str], List[int]]]:
    """
    Split each text into period-terminated sentences and extract labels.

    Args:
        batch_texts: List of text strings to process
        min_words: Shortest sentence (in words) to keep
        max_words: Longest sentence (in words) to keep

    Returns:
        List of (words, labels) tuples
    """
    batch_sequences = []
    for text in batch_texts:
        # Split after periods while keeping them attached to sentences
        for sent in _SENTENCE_END_RE.split(text):
            if not sent.strip():
                continue
            words, labels = extract_words_and_labels(sent)
            if min_words <= len(words) <= max_words:
                batch_sequences.append((words, labels))
    return batch_sequences


def process_texts_in_batches(texts: List[str],
                             batch_size: int = 10,
                             max_texts: Optional[int] = None,
                             output_dir: str = 'processed_data',
                             min_words: int = 3,
                             max_words: int = 100) -> int:
    """
    Process texts in batches and save each batch to disk as a pickle.

    Each non-empty batch is written to ``output_dir/batch_NNNN.pkl``. Files
    from an earlier run with the same names are overwritten, but extra,
    higher-numbered leftovers are not removed.

    Args:
        texts: List of raw texts
        batch_size: Number of texts to process at once
        max_texts: Maximum number of texts to process (None = all)
        output_dir: Directory to save batch files
        min_words: Shortest sentence (in words) to keep
        max_words: Longest sentence (in words) to keep

    Returns:
        Total number of sequences processed
    """
    os.makedirs(output_dir, exist_ok=True)

    texts_to_process = texts[:max_texts] if max_texts is not None else texts
    total_sequences = 0
    batch_num = 0

    for start in tqdm(range(0, len(texts_to_process), batch_size), desc="Processing batches"):
        batch_sequences = _process_text_batch(
            texts_to_process[start:start + batch_size], min_words, max_words)

        if batch_sequences:
            batch_file = os.path.join(output_dir, f'batch_{batch_num:04d}.pkl')
            with open(batch_file, 'wb') as f:
                pickle.dump(batch_sequences, f)
            total_sequences += len(batch_sequences)
            batch_num += 1

        # Release the batch before building the next one
        del batch_sequences
        gc.collect()

    print(f"\nProcessed {total_sequences:,} sequences in {batch_num} batches")
    print(f"Batch files saved to {output_dir}/")
    return total_sequences


# ============================================================================
# EXPLORATORY DATA ANALYSIS
# ============================================================================

def _list_batch_files(batch_dir: str) -> List[str]:
    """Return the ``batch_*.pkl`` files in ``batch_dir``, sorted by name."""
    return sorted(glob.glob(os.path.join(batch_dir, _BATCH_FILE_PATTERN)))


def get_dataset_statistics_incremental(batch_dir: str) -> Dict:
    """
    Calculate dataset statistics incrementally from batch files.

    Loads one batch file at a time; only per-sequence lengths, label counts
    and the set of unique words are accumulated.

    Args:
        batch_dir: Directory containing batch files

    Returns:
        Dictionary with statistics. ``class_imbalance_ratio`` is
        (#label-0 words) / (#punctuated words), or ``inf`` when no word
        carries punctuation.

    Raises:
        ValueError: If the directory contains no sequences.
    """
    batch_files = _list_batch_files(batch_dir)

    sequence_lengths: List[int] = []
    label_counts: Counter = Counter()
    unique_words = set()
    num_words = 0

    print(f"Calculating statistics from {len(batch_files)} batch files...")
    for batch_file in tqdm(batch_files, desc="Processing batches for stats"):
        with open(batch_file, 'rb') as f:
            batch_sequences = pickle.load(f)

        for words, labels in batch_sequences:
            sequence_lengths.append(len(words))
            label_counts.update(labels)
            unique_words.update(words)
            num_words += len(words)

        del batch_sequences
        gc.collect()

    if not sequence_lengths:
        raise ValueError(f"No sequences found in {batch_dir}")

    lengths = np.asarray(sequence_lengths)
    no_punct = label_counts[0]
    with_punct = sum(count for label, count in label_counts.items() if label != 0)

    stats = {
        'num_sequences': len(sequence_lengths),
        'num_words': num_words,
        'num_unique_words': len(unique_words),
        'sequence_length': {
            'min': int(np.min(lengths)),
            'max': int(np.max(lengths)),
            'mean': float(np.mean(lengths)),
            'median': float(np.median(lengths)),
            'std': float(np.std(lengths)),
            'percentile_95': float(np.percentile(lengths, 95)),
            'percentile_99': float(np.percentile(lengths, 99)),
        },
        'label_distribution': dict(label_counts),
        'class_imbalance_ratio': float(no_punct / with_punct) if with_punct else float('inf'),
    }

    del unique_words, sequence_lengths, lengths
    gc.collect()
    return stats


def split_batches_to_directories(batch_dir: str,
                                 output_dir: str,
                                 train_ratio: float = 0.8,
                                 val_ratio: float = 0.1,
                                 test_ratio: float = 0.1,
                                 seed: int = SEED) -> Tuple[str, str, str]:
    """
    Split sequences into train/val/test batch directories.

    MEMORY EFFICIENT: sequences are routed one source batch at a time; only a
    one-byte split assignment per sequence is kept in memory.

    Creates:
        - output_dir/batches_train/batch_0000.pkl, batch_0001.pkl, ...
        - output_dir/batches_val/batch_0000.pkl, batch_0001.pkl, ...
        - output_dir/batches_test/batch_0000.pkl, batch_0001.pkl, ...

    The shuffle uses a private ``random.Random(seed)``, which yields exactly
    the same permutation as ``random.seed(seed); random.shuffle(...)`` without
    touching the global random state.

    Args:
        batch_dir: Directory containing batch_*.pkl files
        output_dir: Base directory for split batch directories
        train_ratio: Ratio for training set
        val_ratio: Ratio for validation set
        test_ratio: Ratio for test set
        seed: Random seed for reproducibility

    Returns:
        Tuple of (train_dir, val_dir, test_dir) - directories containing split batches

    Raises:
        ValueError: If the ratios do not sum to 1.0.
        FileNotFoundError: If batch_dir contains no batch files.
    """
    # Written as "not <" so that NaN ratios are rejected too.
    if not abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6:
        raise ValueError("Ratios must sum to 1.0")

    batch_files = _list_batch_files(batch_dir)
    if not batch_files:
        raise FileNotFoundError(f"No batch files found in {batch_dir}")

    print("=" * 80)
    print("SPLITTING DATA (Memory-Efficient - No Merging)")
    print("=" * 80)
    print(f"Processing from: {batch_dir}")
    print(f"Saving to: {output_dir}\n")

    train_dir = os.path.join(output_dir, 'batches_train')
    val_dir = os.path.join(output_dir, 'batches_val')
    test_dir = os.path.join(output_dir, 'batches_test')
    split_dirs = (train_dir, val_dir, test_dir)
    for split_dir in split_dirs:
        os.makedirs(split_dir, exist_ok=True)

    # Step 1: Count total sequences
    print("Step 1: Counting total sequences...")
    total_sequences = 0
    for batch_file in tqdm(batch_files, desc="Counting"):
        with open(batch_file, 'rb') as f:
            batch_seqs = pickle.load(f)
        total_sequences += len(batch_seqs)
        del batch_seqs
        gc.collect()
    print(f"✓ Total sequences: {total_sequences:,}\n")

    # Step 2: Shuffle global indices and assign each one to a split
    print("Step 2: Generating split assignments...")
    indices = list(range(total_sequences))
    random.Random(seed).shuffle(indices)

    train_end = int(train_ratio * total_sequences)
    val_end = train_end + int(val_ratio * total_sequences)

    # assignment[i] is the split of global sequence i: 0=train, 1=val, 2=test
    assignment = bytearray(total_sequences)
    for idx in indices[train_end:val_end]:
        assignment[idx] = 1
    for idx in indices[val_end:]:
        assignment[idx] = 2
    del indices

    num_train = train_end
    num_val = val_end - train_end
    num_test = total_sequences - val_end
    denom = total_sequences or 1  # avoid dividing by zero on empty batches
    print(f"✓ Train: {num_train:,} sequences ({num_train/denom*100:.1f}%)")
    print(f"✓ Val: {num_val:,} sequences ({num_val/denom*100:.1f}%)")
    print(f"✓ Test: {num_test:,} sequences ({num_test/denom*100:.1f}%)\n")

    # Step 3: Route sequences into per-split batch files
    print("Step 3: Routing sequences to split directories...")
    files_written = [0, 0, 0]
    current_idx = 0
    for batch_file in tqdm(batch_files, desc="Splitting batches"):
        with open(batch_file, 'rb') as f:
            batch_seqs = pickle.load(f)

        buckets: Tuple[list, list, list] = ([], [], [])
        for seq in batch_seqs:
            buckets[assignment[current_idx]].append(seq)
            current_idx += 1

        # Save split batches (only if non-empty)
        for split_id, bucket in enumerate(buckets):
            if not bucket:
                continue
            out_path = os.path.join(split_dirs[split_id],
                                    f'batch_{files_written[split_id]:04d}.pkl')
            with open(out_path, 'wb') as f:
                pickle.dump(bucket, f)
            files_written[split_id] += 1

        del batch_seqs, buckets
        gc.collect()

    train_batch_num, val_batch_num, test_batch_num = files_written
    print(f"\n✓ Split complete!")
    print(f" Train: {train_batch_num} batch files in {train_dir}")
    print(f" Val: {val_batch_num} batch files in {val_dir}")
    print(f" Test: {test_batch_num} batch files in {test_dir}")
    print("\n" + "=" * 80)
    print("✓ DATA SPLIT COMPLETE!")
    print("=" * 80)

    return train_dir, val_dir, test_dir


def load_sequences_from_batches(batch_dir: str) -> List[Tuple[List[str], List[int]]]:
    """
    Load all sequences from a batch directory into one list.

    Use this for small datasets or val/test sets.

    Args:
        batch_dir: Directory containing batch_*.pkl files

    Returns:
        List of (words, labels) tuples, in batch-file order
    """
    all_sequences = []
    for batch_file in tqdm(_list_batch_files(batch_dir),
                           desc=f"Loading {os.path.basename(batch_dir)}"):
        with open(batch_file, 'rb') as f:
            all_sequences.extend(pickle.load(f))

    print(f"Loaded {len(all_sequences):,} sequences from {batch_dir}")
    return all_sequences


def iterate_sequences_from_batches(batch_dir: str) -> Iterator[Tuple[List[str], List[int]]]:
    """
    Generator that yields sequences from a batch directory one at a time.

    Only one batch file is held in memory at a time.

    Args:
        batch_dir: Directory containing batch_*.pkl files

    Yields:
        (words, labels) tuples
    """
    for batch_file in _list_batch_files(batch_dir):
        with open(batch_file, 'rb') as f:
            batch_seqs = pickle.load(f)
        yield from batch_seqs
        # Clear memory after each batch
        del batch_seqs
        gc.collect()


def count_sequences_in_batches(batch_dir: str) -> int:
    """
    Count total sequences in a batch directory, one batch file at a time.

    Args:
        batch_dir: Directory containing batch_*.pkl files

    Returns:
        Total number of sequences
    """
    total = 0
    for batch_file in _list_batch_files(batch_dir):
        with open(batch_file, 'rb') as f:
            batch_seqs = pickle.load(f)
        total += len(batch_seqs)
        del batch_seqs
        gc.collect()
    return total


def print_dataset_statistics(stats: Dict):
    """
    Print dataset statistics in a readable format.

    Args:
        stats: Statistics dictionary from get_dataset_statistics_incremental
    """
    print("=" * 80)
    print("DATASET STATISTICS")
    print("=" * 80)
    print(f"Total sequences: {stats['num_sequences']:,}")
    print(f"Total words: {stats['num_words']:,}")
    print(f"Unique words: {stats['num_unique_words']:,}")
    print("\nSequence length statistics:")
    for key, value in stats['sequence_length'].items():
        print(f" {key}: {value:.2f}")

    print("\n" + "=" * 80)
    print("LABEL DISTRIBUTION")
    print("=" * 80)
    num_words = stats['num_words']
    for label in sorted(stats['label_distribution']):
        count = stats['label_distribution'][label]
        pct = (count / num_words) * 100 if num_words else 0.0
        punct = LABEL_TO_PUNCT.get(label, '')
        punct_name = _LABEL_TO_NAME.get(label, 'UNKNOWN')
        print(f"{label}: {punct_name:15} '{punct}' - {count:7,} ({pct:5.2f}%)")

    print(f"\nClass imbalance ratio (no_punct/punct): {stats['class_imbalance_ratio']:.2f}")


# ============================================================================
# VOCABULARY CLASS
# ============================================================================

class Vocabulary:
    """
    Vocabulary class for word-to-index and index-to-word mappings.

    This class handles:
        - Building vocabulary from sequences (in memory or from batch files)
        - Word-to-index and index-to-word mappings
        - Special tokens (<PAD> at index 0, <UNK> at index 1)
        - Frequency filtering
        - Vocabulary size limiting
    """

    PAD_TOKEN = '<PAD>'
    UNK_TOKEN = '<UNK>'
    PAD_IDX = 0
    UNK_IDX = 1

    def __init__(self, max_vocab_size: Optional[int] = None, min_freq: int = 2):
        """
        Initialize vocabulary.

        Args:
            max_vocab_size: Maximum vocabulary size including the two special
                tokens (None = unlimited). The special tokens are always kept.
            min_freq: Minimum word frequency to include in vocabulary
        """
        self.word2idx = {self.PAD_TOKEN: self.PAD_IDX, self.UNK_TOKEN: self.UNK_IDX}
        self.idx2word = {self.PAD_IDX: self.PAD_TOKEN, self.UNK_IDX: self.UNK_TOKEN}
        self.word_freq = Counter()
        self.max_vocab_size = max_vocab_size
        self.min_freq = min_freq

    def build_vocab(self, sequences: List[Tuple[List[str], List[int]]]):
        """
        Count word frequencies in ``sequences`` and build the index mappings.

        Args:
            sequences: List of (words, labels) tuples
        """
        for words, _ in sequences:
            self.word_freq.update(words)
        self._build_mappings()

    def build_vocab_incremental(self, batch_dir: str):
        """
        Build vocabulary incrementally from batch files.

        Only one batch file is loaded at a time.

        Args:
            batch_dir: Directory containing batch files
        """
        batch_files = _list_batch_files(batch_dir)
        print(f"Building vocabulary from {len(batch_files)} batch files...")

        for batch_file in tqdm(batch_files, desc="Processing batches"):
            with open(batch_file, 'rb') as f:
                batch_sequences = pickle.load(f)
            for words, _ in batch_sequences:
                self.word_freq.update(words)
            del batch_sequences
            gc.collect()

        num_kept = self._build_mappings()
        num_unique = len(self.word_freq)
        coverage = num_kept / num_unique * 100 if num_unique else 0.0
        print(f"Vocabulary built with {len(self)} words")
        print(f" Total unique words in corpus: {num_unique}")
        print(f" Words kept (freq >= {self.min_freq}): {num_kept}")
        print(f" Coverage: {coverage:.2f}%")

    def _build_mappings(self) -> int:
        """
        Rebuild word2idx/idx2word from ``self.word_freq``.

        Keeps words with frequency >= ``min_freq``, most frequent first (ties
        keep first-seen order because the sort is stable), truncated so the
        vocabulary including the special tokens fits ``max_vocab_size``.

        Returns:
            Number of corpus words kept (excluding special tokens)
        """
        specials = (self.PAD_TOKEN, self.UNK_TOKEN)
        kept = [
            word for word, freq in self.word_freq.items()
            if freq >= self.min_freq and word not in specials
        ]
        kept.sort(key=self.word_freq.__getitem__, reverse=True)
        if self.max_vocab_size is not None:
            # -2 for PAD and UNK; clamp so tiny limits don't become negative slices
            kept = kept[:max(0, self.max_vocab_size - 2)]

        self.word2idx = {self.PAD_TOKEN: self.PAD_IDX, self.UNK_TOKEN: self.UNK_IDX}
        self.idx2word = {self.PAD_IDX: self.PAD_TOKEN, self.UNK_IDX: self.UNK_TOKEN}
        for idx, word in enumerate(kept, start=len(self.word2idx)):
            self.word2idx[word] = idx
            self.idx2word[idx] = word
        return len(kept)

    def encode(self, words: List[str]) -> List[int]:
        """
        Convert words to indices; out-of-vocabulary words map to the UNK index.

        Args:
            words: List of words

        Returns:
            List of indices
        """
        # .get keeps older pickled vocabularies (different UNK spelling) usable.
        unk_idx = self.word2idx.get(self.UNK_TOKEN, self.UNK_IDX)
        return [self.word2idx.get(word, unk_idx) for word in words]

    def decode(self, indices: List[int]) -> List[str]:
        """
        Convert indices to words; unknown indices map to the UNK token.

        Args:
            indices: List of indices

        Returns:
            List of words
        """
        return [self.idx2word.get(idx, self.UNK_TOKEN) for idx in indices]

    def __len__(self):
        return len(self.word2idx)

    def save(self, path: str):
        """Save vocabulary to file (pickle)."""
        with open(path, 'wb') as f:
            pickle.dump(self, f)
        print(f"Vocabulary saved to {path}")

    @staticmethod
    def load(path: str) -> 'Vocabulary':
        """
        Load vocabulary from file.

        NOTE(security): unpickling can execute arbitrary code; only load
        files you trust.
        """
        with open(path, 'rb') as f:
            vocab = pickle.load(f)
        print(f"Vocabulary loaded from {path}")
        return vocab


# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================

def reconstruct_text_with_punctuation(words: List[str], labels: List[int]) -> str:
    """
    Reconstruct text with punctuation from words and labels.

    Each mark is attached directly to its word ("word، next"), matching the
    output of PunctuationRestorer.predict. Extra words or labels beyond the
    shorter of the two lists are ignored.

    Args:
        words: List of words
        labels: List of punctuation labels (ints or numpy integers)

    Returns:
        Text with punctuation marks
    """
    return ' '.join(word + LABEL_TO_PUNCT.get(int(label), '')
                    for word, label in zip(words, labels))


def save_processed_data(sequences: List[Tuple[List[str], List[int]]], filepath: str):
    """
    Save processed sequences to a pickle file.

    Args:
        sequences: List of (words, labels) tuples
        filepath: Path to save file
    """
    with open(filepath, 'wb') as f:
        pickle.dump(sequences, f)
    print(f"Saved {len(sequences)} sequences to {filepath}")


def load_processed_data(filepath: str) -> List[Tuple[List[str], List[int]]]:
    """
    Load processed sequences from a pickle file.

    NOTE(security): unpickling can execute arbitrary code; only load files
    you trust.

    Args:
        filepath: Path to load file

    Returns:
        List of (words, labels) tuples
    """
    with open(filepath, 'rb') as f:
        sequences = pickle.load(f)
    print(f"Loaded {len(sequences)} sequences from {filepath}")
    return sequences


class BiLSTM(nn.Module):
    """
    BiLSTM for sequence labeling (one punctuation class per word).

    forward() takes a batch-first tensor of word indices (index 0 is padding)
    and the true length of each sequence. It returns logits of shape
    (batch, padded_seq_len, num_classes); positions beyond a sequence's
    length are padding and should be masked by the caller.
    """

    def __init__(self, vocab_size, embedding_dim=300, hidden_dim=256, num_layers=2,
                 num_classes=7, dropout=0.3, pretrained_embeddings=None):
        super(BiLSTM, self).__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Embedding layer (index 0 is the padding token)
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        if pretrained_embeddings is not None:
            # Accept tensors or array-likes (e.g. numpy) of shape (vocab_size, embedding_dim)
            weights = torch.as_tensor(pretrained_embeddings, dtype=self.embedding.weight.dtype)
            with torch.no_grad():
                self.embedding.weight.copy_(weights)

        # BiLSTM (inter-layer dropout only applies with more than one layer)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0,
        )

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Output layer (per timestep); *2 for the two directions
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, word_indices, lengths):
        """
        Args:
            word_indices: LongTensor (batch, seq_len) of vocabulary indices
            lengths: true length of each sequence (tensor or list of ints, all >= 1)

        Returns:
            Logits (batch, seq_len, num_classes)
        """
        embedded = self.dropout(self.embedding(word_indices))

        # Pack so the LSTM skips padding; lengths must live on the CPU
        packed = pack_padded_sequence(
            embedded,
            torch.as_tensor(lengths, dtype=torch.int64).cpu(),
            batch_first=True,
            enforce_sorted=False,
        )
        packed_output, _ = self.lstm(packed)

        # total_length keeps the time dimension equal to the input's, even when
        # the batch was padded beyond its longest sequence (labels line up).
        lstm_output, _ = pad_packed_sequence(
            packed_output, batch_first=True, total_length=word_indices.size(1))
        lstm_output = self.dropout(lstm_output)

        # Output for each timestep
        return self.fc(lstm_output)


print("✓ Model architecture defined")


class PunctuationRestorer:
    """
    Inference wrapper: downloads a trained BiLSTM checkpoint and its vocabulary
    from the Hugging Face Hub and restores punctuation in raw Arabic text.

    NOTE(security): the vocabulary is unpickled and the checkpoint is loaded
    with ``weights_only=False``; both can execute arbitrary code, so only point
    ``repo_id`` at a repository you trust.
    """

    def __init__(self, repo_id="malkhuzanie/arabic-punctuation-checkpoints",
                 model_file="approach1_best_model_256_wf_fc_loss_smoth_none_v1.pth",
                 vocab_file="vocabulary.pkl",
                 max_len=256,
                 embedding_dim=300,
                 hidden_dim=256,
                 num_layers=2):
        """
        Args:
            repo_id: Hugging Face Hub repository holding the files
            model_file: checkpoint file name (a state dict, or a dict with 'model_state_dict')
            vocab_file: pickled Vocabulary file name
            max_len: longest word window fed to the model at once; longer
                texts are processed in consecutive windows
            embedding_dim, hidden_dim, num_layers: must match the checkpoint

        Raises:
            ImportError: If huggingface_hub is not installed.
            ValueError: If max_len < 1.
        """
        if hf_hub_download is None:
            raise ImportError(
                "PunctuationRestorer requires the 'huggingface_hub' package "
                "(pip install huggingface_hub)")
        if max_len < 1:
            raise ValueError("max_len must be >= 1")

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        print("Downloading vocabulary and model...")
        vocab_path = hf_hub_download(repo_id=repo_id, filename=vocab_file)
        model_path = hf_hub_download(repo_id=repo_id, filename=model_file)

        self.vocab = self._load_vocabulary(vocab_path)

        # Load Model
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        self.model = BiLSTM(
            vocab_size=len(self.vocab.word2idx),
            embedding_dim=embedding_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_classes=len(PUNCTUATION_MAP),
        ).to(self.device)

        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        self.model.load_state_dict(state_dict)
        self.model.eval()

        self.max_len = max_len

    @staticmethod
    def _load_vocabulary(path):
        """Unpickle a Vocabulary, remapping the class if it was pickled under another module path."""

        class _VocabularyUnpickler(pickle.Unpickler):
            def find_class(self, module, name):
                if name == 'Vocabulary':
                    return Vocabulary
                return super().find_class(module, name)

        with open(path, 'rb') as f:
            try:
                return pickle.load(f)
            except (AttributeError, ImportError):
                # e.g. pickled from a notebook's __main__, which has no Vocabulary here
                f.seek(0)
                return _VocabularyUnpickler(f).load()

    def predict(self, text):
        """
        Return ``text`` with predicted punctuation attached to each word.

        The input is run through preprocess_text, so existing punctuation is
        replaced by the model's predictions and the output uses the normalized
        word forms (diacritics/tatweel removed, alef variants and ة normalized).
        Texts longer than ``max_len`` words are predicted in consecutive
        windows; every word is kept. Returns ``text`` unchanged when it
        contains no words.
        """
        # Prepare Input
        words, _ = extract_words_and_labels(text, preprocess=True)
        if not words:
            return text

        # Lookup done here (not via vocab.encode) so vocabularies pickled by
        # other versions of the class still work; index 1 is UNK by construction.
        unk_idx = self.vocab.word2idx.get(Vocabulary.UNK_TOKEN, Vocabulary.UNK_IDX)
        indices = [self.vocab.word2idx.get(w, unk_idx) for w in words]

        chunks = [indices[start:start + self.max_len]
                  for start in range(0, len(indices), self.max_len)]
        batch = pad_sequence(
            [torch.tensor(chunk, dtype=torch.long) for chunk in chunks],
            batch_first=True,
            padding_value=Vocabulary.PAD_IDX,
        ).to(self.device)
        lengths = torch.tensor([len(chunk) for chunk in chunks], dtype=torch.long)

        # Predict
        with torch.no_grad():
            logits = self.model(batch, lengths)
        preds = logits.argmax(dim=-1).cpu().tolist()

        # Drop padding positions and flatten the windows back into one sequence
        labels = [label for row, chunk in zip(preds, chunks) for label in row[:len(chunk)]]

        # Reconstruct
        return ' '.join(word + LABEL_TO_PUNCT.get(label, '')
                        for word, label in zip(words, labels))