| | """ |
| | Shared Preprocessing Module for Arabic Text Punctuation Prediction |
| | |
| | This module contains all shared functionality for data preprocessing, |
| | label extraction, and dataset creation. |
| | |
Authors: Mohammed Alkhuzanie, Mohammed Fael
| | Date: December 2025 |
| | """ |
| |
|
| | import os |
| | import re |
| | import gc |
| | import pickle |
| | import zipfile |
| | from collections import Counter, defaultdict |
| | from typing import List, Tuple, Dict, Optional |
| | import numpy as np |
| | import pandas as pd |
| | from tqdm import tqdm |
| | import torch |
| | import torch.nn as nn |
| | import glob |
| | import torch.nn.functional as F |
| | from torch.utils.data import Dataset |
| | from torch.nn.utils.rnn import pad_sequence |
| | from torch.utils.data import Dataset, DataLoader |
| | from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence |
| | from huggingface_hub import hf_hub_download |
| |
|
| | |
| | |
| | |
| |
|
| | |
# Punctuation class names -> integer label ids (7-way token classification).
PUNCTUATION_MAP = {
    'NO_PUNCT': 0,
    'QUESTION': 1,
    'COMMA': 2,
    'COLON': 3,
    'SEMICOLON': 4,
    'EXCLAMATION': 5,
    'PERIOD': 6
}

# Label id -> punctuation character emitted when reconstructing text.
# NOTE: label 1 maps back to the Arabic question mark '؟' on output, while
# PUNCTUATION_CHARS below keys on the ASCII '?' — normalize_punctuation()
# rewrites '؟' to '?' before label extraction, so the two directions agree.
LABEL_TO_PUNCT = {
    0: '',
    1: '؟',
    2: '،',
    3: ':',
    4: '؛',
    5: '!',
    6: '.'
}

# Punctuation character -> label id; applied to normalized text when
# extracting training labels (see extract_words_and_labels).
PUNCTUATION_CHARS = {
    '?': 1,
    '،': 2,
    ':': 3,
    '؛': 4,
    '!': 5,
    '.': 6
}

# Global random seed for reproducible train/val/test splits.
SEED = 42
| |
|
| |
|
| | |
| | |
| | |
| |
|
def load_text_files(data_dir: str) -> List[str]:
    """
    Recursively collect the contents of every non-empty .txt file under a directory.

    Args:
        data_dir: Root directory to search.

    Returns:
        List of raw file contents (files that are empty or whitespace-only
        are skipped; unreadable files are reported and skipped).
    """
    collected = []

    for dirpath, _subdirs, filenames in os.walk(data_dir):
        for name in filenames:
            if not name.endswith('.txt'):
                continue
            file_path = os.path.join(dirpath, name)
            try:
                with open(file_path, 'r', encoding='utf-8') as handle:
                    body = handle.read()
                if body.strip():
                    collected.append(body)
            except Exception as e:
                print(f"Error reading {file_path}: {e}")

    return collected
| |
|
| |
|
| | |
| | |
| | |
| |
|
def remove_diacritics(text: str) -> str:
    """
    Strip Arabic diacritics (tashkeel) from text.

    Diacritics are optional vowel marks; dropping them is a standard Arabic
    NLP normalization that shrinks the vocabulary.

    Args:
        text: Input Arabic text.

    Returns:
        The text with all characters in U+0617–U+061A and U+064B–U+0652 removed.
    """
    return re.sub(r'[\u0617-\u061A\u064B-\u0652]', '', text)
| |
|
| |
|
def normalize_arabic_text(text: str) -> str:
    """
    Normalize Arabic characters to standard forms.

    Normalizations:
        - أ، إ، آ → ا (collapse alef variants)
        - ة → ه (taa marbuta to haa)

    Args:
        text: Input Arabic text.

    Returns:
        Normalized text.
    """
    # Single-pass character map; equivalent to sequential re.sub replacements.
    table = str.maketrans({'إ': 'ا', 'أ': 'ا', 'آ': 'ا', 'ة': 'ه'})
    return text.translate(table)
| |
|
| |
|
def normalize_punctuation(text: str) -> str:
    """
    Normalize punctuation marks to their canonical forms.

    Ensures consistency across different Unicode representations of the
    same punctuation mark:
        - ؟ (Arabic question mark U+061F) → ? (ASCII question mark)
        - , (ASCII comma) → ، (Arabic comma)
        - ; (ASCII semicolon) → ؛ (Arabic semicolon)
        - ۔ (Arabic full stop U+06D4) → . (ASCII period)
    Colons and exclamation marks are left untouched.

    Args:
        text: Input text.

    Returns:
        Text with normalized punctuation.
    """
    # The source and target character sets are disjoint, so a single
    # translate pass is equivalent to the sequential replacements.
    table = str.maketrans({
        '؟': '?',
        ',': '،',
        ';': '؛',
        '\u06d4': '.',
    })
    return text.translate(table)
| |
|
| |
|
def normalize_punctuation_spacing(text: str) -> str:
    """
    Insert a space after punctuation marks that are glued to the next character.

    Fixes formatting like "word1!word2" → "word1! word2" so label extraction
    sees cleanly tokenized input.

    Args:
        text: Input text.

    Returns:
        Text with a space after every punctuation mark that lacked one.
    """
    # One regex pass per mark, in this fixed order. The sequential passes
    # matter for adjacent punctuation (e.g. "a.!b" becomes "a. ! b").
    for mark in ('?', '،', ':', '؛', '!', '.'):
        pattern = '(' + re.escape(mark) + r')([^ \n\t])'
        text = re.sub(pattern, r'\1 \2', text)

    return text
| |
|
| |
|
| |
|
def remove_tatweel(text: str) -> str:
    """
    Remove the Arabic tatweel character (ـ, U+0640).

    Tatweel only stretches glyphs for justification and carries no
    linguistic meaning.

    Args:
        text: Input Arabic text.

    Returns:
        Text with every tatweel removed.
    """
    return text.replace('ـ', '')
| |
|
| |
|
def remove_punctuation_marks(text: str) -> str:
    """
    Strip every target punctuation mark from the text.

    Removes the marks the model predicts — ?, ،, :, ؛, !, . — then collapses
    any whitespace runs left behind.

    Args:
        text: Input text possibly containing punctuation.

    Returns:
        Punctuation-free text with single-space word separation.
    """
    # All target marks are single characters, so one translate pass is
    # equivalent to removing them one by one.
    stripped = text.translate({ord(ch): None for ch in PUNCTUATION_CHARS})

    # Normalize whitespace (also trims leading/trailing spaces).
    return ' '.join(stripped.split())
| |
|
| |
|
def preprocess_text(text: str) -> str:
    """
    Run the full cleaning pipeline on a piece of text.

    Order matters: diacritics/tatweel removal and character normalization
    happen before punctuation normalization and spacing fixes.

    Args:
        text: Input text.

    Returns:
        Preprocessed text with normalized characters, punctuation, and spacing.
    """
    pipeline = (
        remove_diacritics,
        normalize_arabic_text,
        remove_tatweel,
        normalize_punctuation,
        normalize_punctuation_spacing,
    )
    for step in pipeline:
        text = step(text)

    # Collapse whitespace runs into single spaces.
    return ' '.join(text.split())
| |
|
| |
|
| | |
| | |
| | |
| |
|
def extract_words_and_labels(text: str, preprocess: bool = True) -> Tuple[List[str], List[int]]:
    """
    Convert punctuated text into parallel word / punctuation-label lists.

    This is the core routine that turns raw text into training targets:
    each word receives the label of the punctuation mark that directly
    follows it (0 = no punctuation).

    Example:
        Input: "أكل الولد الخبز، وشرب الماء."
        Output:
            words = ['أكل', 'الولد', 'الخبز', 'وشرب', 'الماء']
            labels = [0, 0, 2, 0, 6]  # 0=none, 2=comma, 6=period

    Args:
        text: Raw text with punctuation.
        preprocess: Whether to run the cleaning pipeline first.

    Returns:
        Tuple of (words, labels):
            - words: words with punctuation stripped
            - labels: punctuation label ids (0-6), one per word
    """
    if preprocess:
        text = preprocess_text(text)

    words: List[str] = []
    labels: List[int] = []

    for token in text.split():
        if not token.strip():
            continue

        # A trailing punctuation character determines the word's label.
        trailing = token[-1] if token else ''
        if trailing in PUNCTUATION_CHARS:
            tag = PUNCTUATION_CHARS[trailing]
            core = token[:-1]
        else:
            tag = 0
            core = token

        core = core.strip()
        if core:
            words.append(core)
            labels.append(tag)

    return words, labels
| |
|
| |
|
def _process_text_batch(batch_texts: List[str]) -> List[Tuple[List[str], List[int]]]:
    """
    Convert a batch of raw texts into (words, labels) training sequences.

    Args:
        batch_texts: Raw text strings to process.

    Returns:
        List of (words, labels) tuples, one per kept sentence.
    """
    results: List[Tuple[List[str], List[int]]] = []

    for raw_text in batch_texts:
        # Sentence boundary heuristic: split on whitespace after a period.
        for sentence in re.split(r'(?<=\.)\s+', raw_text):
            if not sentence.strip():
                continue

            words, labels = extract_words_and_labels(sentence)

            # Keep only sequences of a useful length for training.
            if 3 <= len(words) <= 100:
                results.append((words, labels))

    return results
| |
|
| |
|
def process_texts_in_batches(texts: List[str],
                             batch_size: int = 10,
                             max_texts: Optional[int] = None,
                             output_dir: str = 'processed_data') -> int:
    """
    Process texts in batches and save each batch to disk incrementally.

    Keeps memory bounded by processing `batch_size` texts at a time and
    pickling each non-empty result as output_dir/batch_NNNN.pkl.

    Args:
        texts: List of raw texts.
        batch_size: Number of texts to process at once.
        max_texts: Maximum number of texts to process (None = all).
        output_dir: Directory where batch files are written.

    Returns:
        Total number of (words, labels) sequences produced.
    """
    # Fixed: removed redundant function-local `import os` / `import gc`
    # that shadowed the module-level imports.
    os.makedirs(output_dir, exist_ok=True)

    texts_to_process = texts[:max_texts] if max_texts else texts
    total_sequences = 0
    batch_num = 0

    for i in tqdm(range(0, len(texts_to_process), batch_size), desc="Processing batches"):
        batch_texts = texts_to_process[i:i + batch_size]
        batch_sequences = _process_text_batch(batch_texts)

        # Only write a file when the batch yielded at least one sequence,
        # so batch numbering stays dense.
        if batch_sequences:
            batch_file = os.path.join(output_dir, f'batch_{batch_num:04d}.pkl')
            with open(batch_file, 'wb') as f:
                pickle.dump(batch_sequences, f)

            total_sequences += len(batch_sequences)
            batch_num += 1

        # Free the batch before processing the next one.
        del batch_sequences
        gc.collect()

    print(f"\nProcessed {total_sequences:,} sequences in {batch_num} batches")
    print(f"Batch files saved to {output_dir}/")

    return total_sequences
| |
|
| | |
| | |
| | |
def get_dataset_statistics_incremental(batch_dir: str) -> Dict:
    """
    Calculate dataset statistics incrementally from batch files.

    Memory-efficient for large datasets — loads one batch at a time and
    releases it before loading the next.

    Args:
        batch_dir: Directory containing batch_*.pkl files.

    Returns:
        Dictionary with sequence/word counts, sequence-length statistics,
        the per-label distribution, and the class imbalance ratio
        (no-punct count / punctuated count; inf when no punctuated labels).

    Raises:
        ValueError: If no sequences are found in batch_dir.
    """
    # Fixed: removed redundant function-local `import glob` / `import gc`
    # (both are imported at module level).
    batch_files = sorted(glob.glob(os.path.join(batch_dir, 'batch_*.pkl')))

    sequence_lengths = []
    label_counts = Counter()
    unique_words = set()
    num_sequences = 0
    num_words = 0

    print(f"Calculating statistics from {len(batch_files)} batch files...")

    for batch_file in tqdm(batch_files, desc="Processing batches for stats"):
        with open(batch_file, 'rb') as f:
            batch_sequences = pickle.load(f)

        for words, labels in batch_sequences:
            seq_len = len(words)
            sequence_lengths.append(seq_len)
            label_counts.update(labels)
            unique_words.update(words)
            num_words += seq_len
            num_sequences += 1

        # Free the batch before loading the next one.
        del batch_sequences
        gc.collect()

    # Fixed: np.min/np.max on an empty list raised a cryptic numpy error;
    # fail fast with a clear message instead.
    if not sequence_lengths:
        raise ValueError(f"No sequences found in {batch_dir}")

    # Fixed: guard against ZeroDivisionError when no punctuated labels exist.
    punct_total = sum(v for k, v in label_counts.items() if k != 0)
    imbalance = float(label_counts[0]) / punct_total if punct_total else float('inf')

    stats = {
        'num_sequences': num_sequences,
        'num_words': num_words,
        'num_unique_words': len(unique_words),
        'sequence_length': {
            'min': int(np.min(sequence_lengths)),
            'max': int(np.max(sequence_lengths)),
            'mean': float(np.mean(sequence_lengths)),
            'median': float(np.median(sequence_lengths)),
            'std': float(np.std(sequence_lengths)),
            'percentile_95': float(np.percentile(sequence_lengths, 95)),
            'percentile_99': float(np.percentile(sequence_lengths, 99))
        },
        'label_distribution': dict(label_counts),
        'class_imbalance_ratio': imbalance
    }

    # Drop the large intermediates before returning.
    del unique_words, sequence_lengths
    gc.collect()

    return stats
| |
|
| |
|
def split_batches_to_directories(batch_dir: str,
                                 output_dir: str,
                                 train_ratio: float = 0.8,
                                 val_ratio: float = 0.1,
                                 test_ratio: float = 0.1,
                                 seed: int = SEED) -> Tuple[str, str, str]:
    """
    Split sequences into train/val/test batch directories.

    MEMORY EFFICIENT: routes sequences to split directories without loading
    all data at once. Instead of creating single large files, creates:
    - output_dir/batches_train/batch_0000.pkl, batch_0001.pkl, ...
    - output_dir/batches_val/batch_0000.pkl, batch_0001.pkl, ...
    - output_dir/batches_test/batch_0000.pkl, batch_0001.pkl, ...

    Args:
        batch_dir: Directory containing batch_*.pkl files
        output_dir: Base directory for split batch directories
        train_ratio: Ratio for training set
        val_ratio: Ratio for validation set
        test_ratio: Ratio for test set
        seed: Random seed for reproducibility

    Returns:
        Tuple of (train_dir, val_dir, test_dir) - directories containing split batches

    Raises:
        FileNotFoundError: If batch_dir contains no batch files.
    """
    import glob
    import random

    # NOTE(review): `assert` is stripped under `python -O`; a ValueError
    # would validate the ratios more robustly.
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, \
        "Ratios must sum to 1.0"

    batch_files = sorted(glob.glob(os.path.join(batch_dir, 'batch_*.pkl')))

    if not batch_files:
        raise FileNotFoundError(f"No batch files found in {batch_dir}")

    print("=" * 80)
    print("SPLITTING DATA (Memory-Efficient - No Merging)")
    print("=" * 80)
    print(f"Processing from: {batch_dir}")
    print(f"Saving to: {output_dir}\n")

    # One destination directory per split.
    train_dir = os.path.join(output_dir, 'batches_train')
    val_dir = os.path.join(output_dir, 'batches_val')
    test_dir = os.path.join(output_dir, 'batches_test')

    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    # Pass 1: count sequences so each global sequence index can be assigned
    # to a split up front. Batches are loaded one at a time and released.
    print("Step 1: Counting total sequences...")
    total_sequences = 0
    for batch_file in tqdm(batch_files, desc="Counting"):
        with open(batch_file, 'rb') as f:
            batch_seqs = pickle.load(f)
            total_sequences += len(batch_seqs)
            del batch_seqs
            gc.collect()

    print(f"✓ Total sequences: {total_sequences:,}\n")

    # Shuffle the global indices once; the seed makes the split reproducible.
    print("Step 2: Generating split assignments...")
    random.seed(seed)
    indices = list(range(total_sequences))
    random.shuffle(indices)

    # Contiguous slices of the shuffled index list define the three splits.
    train_end = int(train_ratio * total_sequences)
    val_end = train_end + int(val_ratio * total_sequences)

    # Sets give O(1) membership tests in the routing pass below.
    train_indices = set(indices[:train_end])
    val_indices = set(indices[train_end:val_end])
    test_indices = set(indices[val_end:])

    print(f"✓ Train: {len(train_indices):,} sequences ({len(train_indices)/total_sequences*100:.1f}%)")
    print(f"✓ Val: {len(val_indices):,} sequences ({len(val_indices)/total_sequences*100:.1f}%)")
    print(f"✓ Test: {len(test_indices):,} sequences ({len(test_indices)/total_sequences*100:.1f}%)\n")

    # Pass 2: stream each source batch, route its sequences by the global
    # index assigned above, and write one output batch file per split.
    print("Step 3: Routing sequences to split directories...")
    current_idx = 0  # running global sequence index across all source batches
    train_batch_num = 0
    val_batch_num = 0
    test_batch_num = 0

    for batch_file in tqdm(batch_files, desc="Splitting batches"):
        with open(batch_file, 'rb') as f:
            batch_seqs = pickle.load(f)

        # Per-split buffers for this source batch only.
        train_batch = []
        val_batch = []
        test_batch = []

        for seq in batch_seqs:
            if current_idx in train_indices:
                train_batch.append(seq)
            elif current_idx in val_indices:
                val_batch.append(seq)
            else:
                test_batch.append(seq)
            current_idx += 1

        # Write only non-empty partitions so batch numbering stays dense.
        if train_batch:
            with open(os.path.join(train_dir, f'batch_{train_batch_num:04d}.pkl'), 'wb') as f:
                pickle.dump(train_batch, f)
            train_batch_num += 1

        if val_batch:
            with open(os.path.join(val_dir, f'batch_{val_batch_num:04d}.pkl'), 'wb') as f:
                pickle.dump(val_batch, f)
            val_batch_num += 1

        if test_batch:
            with open(os.path.join(test_dir, f'batch_{test_batch_num:04d}.pkl'), 'wb') as f:
                pickle.dump(test_batch, f)
            test_batch_num += 1

        # Free this batch's data before loading the next one.
        del batch_seqs, train_batch, val_batch, test_batch
        gc.collect()

    print(f"\n✓ Split complete!")
    print(f"  Train: {train_batch_num} batch files in {train_dir}")
    print(f"  Val: {val_batch_num} batch files in {val_dir}")
    print(f"  Test: {test_batch_num} batch files in {test_dir}")

    print("\n" + "=" * 80)
    print("✓ DATA SPLIT COMPLETE!")
    print("=" * 80)

    return train_dir, val_dir, test_dir
| |
|
| |
|
def load_sequences_from_batches(batch_dir: str) -> List[Tuple[List[str], List[int]]]:
    """
    Load every sequence from a batch directory into memory.

    Suitable for small datasets or val/test sets; for large training sets
    prefer iterate_sequences_from_batches().

    Args:
        batch_dir: Directory containing batch_*.pkl files.

    Returns:
        List of (words, labels) tuples from all batch files, in file order.
    """
    import glob

    pattern = os.path.join(batch_dir, 'batch_*.pkl')
    all_sequences: List[Tuple[List[str], List[int]]] = []

    for path in tqdm(sorted(glob.glob(pattern)), desc=f"Loading {os.path.basename(batch_dir)}"):
        with open(path, 'rb') as f:
            all_sequences.extend(pickle.load(f))

    print(f"Loaded {len(all_sequences):,} sequences from {batch_dir}")
    return all_sequences
| |
|
| |
|
def iterate_sequences_from_batches(batch_dir: str):
    """
    Lazily yield sequences from a batch directory, one at a time.

    Holds only a single batch in memory; each batch is freed before the
    next one is loaded.

    Args:
        batch_dir: Directory containing batch_*.pkl files.

    Yields:
        (words, labels) tuples, in batch-file order.
    """
    import glob

    for path in sorted(glob.glob(os.path.join(batch_dir, 'batch_*.pkl'))):
        with open(path, 'rb') as f:
            loaded = pickle.load(f)

        yield from loaded

        # Release the batch before loading the next one.
        del loaded
        gc.collect()
| |
|
| |
|
def count_sequences_in_batches(batch_dir: str) -> int:
    """
    Count the total number of sequences across all batch files.

    Loads one batch at a time, so memory use stays bounded even for large
    datasets.

    Args:
        batch_dir: Directory containing batch_*.pkl files.

    Returns:
        Total number of sequences.
    """
    import glob

    total = 0
    for path in sorted(glob.glob(os.path.join(batch_dir, 'batch_*.pkl'))):
        with open(path, 'rb') as f:
            total += len(pickle.load(f))
        # The loaded batch is already unreferenced; nudge the collector.
        gc.collect()

    return total
| |
|
| |
|
def print_dataset_statistics(stats: Dict):
    """
    Pretty-print the statistics dictionary produced by
    get_dataset_statistics_incremental().

    Args:
        stats: Statistics dictionary with num_sequences, num_words,
            num_unique_words, sequence_length, label_distribution, and
            class_imbalance_ratio keys.
    """
    divider = "=" * 80
    print(divider)
    print("DATASET STATISTICS")
    print(divider)
    print(f"Total sequences: {stats['num_sequences']:,}")
    print(f"Total words: {stats['num_words']:,}")
    print(f"Unique words: {stats['num_unique_words']:,}")

    print(f"\nSequence length statistics:")
    for stat_name, stat_value in stats['sequence_length'].items():
        print(f"  {stat_name}: {stat_value:.2f}")

    print("\n" + divider)
    print("LABEL DISTRIBUTION")
    print(divider)

    # Invert the name->id map once instead of scanning it per label.
    label_names = {v: k for k, v in PUNCTUATION_MAP.items()}
    for label in sorted(stats['label_distribution']):
        count = stats['label_distribution'][label]
        pct = (count / stats['num_words']) * 100
        punct = LABEL_TO_PUNCT[label]
        print(f"{label}: {label_names[label]:15} '{punct}' - {count:7,} ({pct:5.2f}%)")

    print(f"\nClass imbalance ratio (no_punct/punct): {stats['class_imbalance_ratio']:.2f}")
| |
|
| |
|
| | |
| | |
| | |
| |
|
class Vocabulary:
    """
    Word <-> index mapping with special tokens, frequency filtering, and an
    optional vocabulary-size cap.

    Index 0 is reserved for '<PAD>' and index 1 for '<UNK>'.
    """

    def __init__(self, max_vocab_size: Optional[int] = None, min_freq: int = 2):
        """
        Initialize an empty vocabulary.

        Args:
            max_vocab_size: Maximum vocabulary size, including the two
                special tokens (None = unlimited).
            min_freq: Minimum corpus frequency for a word to be kept.
        """
        self.word2idx = {'<PAD>': 0, '<UNK>': 1}
        self.idx2word = {0: '<PAD>', 1: '<UNK>'}
        self.word_freq = Counter()
        self.max_vocab_size = max_vocab_size
        self.min_freq = min_freq

    def _finalize(self, verbose: bool = True) -> List[str]:
        """
        Build word2idx/idx2word from the accumulated word_freq counts.

        Words below min_freq are dropped; the rest are sorted by descending
        frequency and truncated to max_vocab_size (minus the 2 special
        tokens). Idempotent: words already indexed keep their ids.

        Args:
            verbose: Print a summary of the built vocabulary.

        Returns:
            The list of words that passed the frequency filter.
        """
        kept = [
            word for word, freq in self.word_freq.items()
            if freq >= self.min_freq
        ]
        kept.sort(key=lambda w: self.word_freq[w], reverse=True)

        if self.max_vocab_size:
            # Reserve slots for <PAD> and <UNK>.
            kept = kept[:self.max_vocab_size - 2]

        for word in kept:
            if word in self.word2idx:  # don't reassign ids on repeated builds
                continue
            idx = len(self.word2idx)
            self.word2idx[word] = idx
            self.idx2word[idx] = word

        if verbose:
            print(f"Vocabulary built with {len(self)} words")
            print(f"  Total unique words in corpus: {len(self.word_freq)}")
            print(f"  Words kept (freq >= {self.min_freq}): {len(kept)}")
            if self.word_freq:  # avoid ZeroDivisionError on an empty corpus
                print(f"  Coverage: {len(self) / len(self.word_freq) * 100:.2f}%")

        return kept

    def build_vocab(self, sequences: List[Tuple[List[str], List[int]]]):
        """
        Build vocabulary from in-memory sequences.

        BUGFIX: previously this only accumulated frequencies and never
        populated word2idx/idx2word, leaving the vocabulary with just
        <PAD>/<UNK>; it now finalizes the index mappings as well.

        Args:
            sequences: List of (words, labels) tuples.
        """
        for words, _ in sequences:
            self.word_freq.update(words)
        self._finalize(verbose=False)

    def build_vocab_incremental(self, batch_dir: str):
        """
        Build vocabulary incrementally from batch files.
        Memory-efficient for large datasets.

        Args:
            batch_dir: Directory containing batch_*.pkl files.
        """
        import glob
        import gc

        batch_files = sorted(glob.glob(os.path.join(batch_dir, 'batch_*.pkl')))

        print(f"Building vocabulary from {len(batch_files)} batch files...")

        for batch_file in tqdm(batch_files, desc="Processing batches"):
            with open(batch_file, 'rb') as f:
                batch_sequences = pickle.load(f)

            for words, _ in batch_sequences:
                self.word_freq.update(words)

            # Free the batch before loading the next one.
            del batch_sequences
            gc.collect()

        self._finalize(verbose=True)

    def encode(self, words: List[str]) -> List[int]:
        """
        Convert words to indices; unknown words map to <UNK>.

        Args:
            words: List of words.

        Returns:
            List of integer indices.
        """
        return [self.word2idx.get(word, self.word2idx['<UNK>']) for word in words]

    def decode(self, indices: List[int]) -> List[str]:
        """
        Convert indices back to words; unknown indices map to '<UNK>'.

        Args:
            indices: List of integer indices.

        Returns:
            List of words.
        """
        return [self.idx2word.get(idx, '<UNK>') for idx in indices]

    def __len__(self):
        return len(self.word2idx)

    def save(self, path: str):
        """Pickle the whole vocabulary object to `path`."""
        with open(path, 'wb') as f:
            pickle.dump(self, f)
        print(f"Vocabulary saved to {path}")

    @staticmethod
    def load(path: str) -> 'Vocabulary':
        """Load a pickled vocabulary from `path`."""
        with open(path, 'rb') as f:
            vocab = pickle.load(f)
        print(f"Vocabulary loaded from {path}")
        return vocab
| |
|
| | |
| | |
| | |
| |
|
def reconstruct_text_with_punctuation(words: List[str], labels: List[int]) -> str:
    """
    Rebuild punctuated text from parallel word and label lists.

    Each punctuation mark is emitted as its own space-separated token after
    the word it follows (label 0 emits nothing).

    Args:
        words: List of words.
        labels: List of punctuation label ids (0-6).

    Returns:
        Space-joined text with punctuation marks inserted.
    """
    pieces: List[str] = []
    for token, label in zip(words, labels):
        mark = LABEL_TO_PUNCT.get(label, '')
        pieces.extend([token, mark] if mark else [token])

    return ' '.join(pieces)
| |
|
| |
|
def save_processed_data(sequences: List[Tuple[List[str], List[int]]],
                        filepath: str):
    """
    Pickle processed sequences to a file.

    Args:
        sequences: List of (words, labels) tuples.
        filepath: Destination path.
    """
    with open(filepath, 'wb') as handle:
        pickle.dump(sequences, handle)

    print(f"Saved {len(sequences)} sequences to {filepath}")
| |
|
| |
|
def load_processed_data(filepath: str) -> List[Tuple[List[str], List[int]]]:
    """
    Load pickled processed sequences from a file.

    Args:
        filepath: Path of the pickle file to read.

    Returns:
        List of (words, labels) tuples.
    """
    with open(filepath, 'rb') as handle:
        sequences = pickle.load(handle)

    print(f"Loaded {len(sequences)} sequences from {filepath}")
    return sequences
| |
|
| |
|
class BiLSTM(nn.Module):
    """
    Bidirectional LSTM tagger: one punctuation-class logit vector per token.
    """

    def __init__(self, vocab_size, embedding_dim=300, hidden_dim=256,
                 num_layers=2, num_classes=7, dropout=0.3, pretrained_embeddings=None):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Token embeddings; index 0 is the padding token.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        if pretrained_embeddings is not None:
            self.embedding.weight.data.copy_(pretrained_embeddings)

        # Inter-layer dropout only applies when the LSTM is stacked.
        self.lstm = nn.LSTM(
            embedding_dim, hidden_dim, num_layers=num_layers,
            batch_first=True, bidirectional=True,
            dropout=dropout if num_layers > 1 else 0
        )

        self.dropout = nn.Dropout(dropout)

        # Bidirectional output is 2 * hidden_dim wide.
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, word_indices, lengths):
        emb = self.dropout(self.embedding(word_indices))

        # Pack so the LSTM skips padded positions; lengths must live on CPU.
        packed_in = pack_padded_sequence(
            emb, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.lstm(packed_in)
        outputs, _ = pad_packed_sequence(packed_out, batch_first=True)

        # Per-token logits: (batch, seq_len, num_classes).
        return self.fc(self.dropout(outputs))
| |
|
# NOTE(review): notebook-style side-effect print at import time — consider
# removing for library use.
print("✓ Model architecture defined")
| |
|
| |
|
class PunctuationRestorer:
    """
    End-to-end Arabic punctuation restoration with a pretrained BiLSTM.

    On construction, downloads the vocabulary pickle and model checkpoint
    from the Hugging Face Hub and loads them onto the available device.
    """

    def __init__(self, repo_id="malkhuzanie/arabic-punctuation-checkpoints",
                 model_file="approach1_best_model_256_wf_fc_loss_smoth_none_v1.pth",
                 vocab_file="vocabulary.pkl"):
        """
        Download and initialize the model and vocabulary.

        Args:
            repo_id: Hugging Face Hub repository id.
            model_file: Checkpoint filename inside the repo.
            vocab_file: Pickled vocabulary filename inside the repo.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        print("Downloading vocabulary and model...")
        vocab_path = hf_hub_download(repo_id=repo_id, filename=vocab_file)
        model_path = hf_hub_download(repo_id=repo_id, filename=model_file)

        # The pickled vocabulary may reference a 'Vocabulary' class under a
        # different module path; on failure, retry with an unpickler that
        # remaps any 'Vocabulary' to the local class.
        class CustomUnpickler(pickle.Unpickler):
            def find_class(self, module, name):
                if name == 'Vocabulary':
                    return Vocabulary
                return super().find_class(module, name)

        with open(vocab_path, 'rb') as f:
            try:
                self.vocab = pickle.load(f)
            except Exception:  # fixed: was a bare `except:` (caught SystemExit etc.)
                f.seek(0)
                self.vocab = CustomUnpickler(f).load()

        # SECURITY: weights_only=False unpickles arbitrary objects from the
        # downloaded checkpoint — only use checkpoints from trusted repos.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        self.model = BiLSTM(
            vocab_size=len(self.vocab.word2idx),
            embedding_dim=300, hidden_dim=256, num_layers=2, num_classes=7
        ).to(self.device)

        # Checkpoints may be a raw state dict or a dict wrapping one.
        state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.max_len = 256  # maximum number of tokens fed to the model

    def predict(self, text):
        """
        Restore punctuation in `text`.

        The input is preprocessed (any existing punctuation is consumed as
        labels and discarded), truncated to max_len tokens, and re-emitted
        with the model's predicted punctuation after each word.

        Args:
            text: Raw Arabic text.

        Returns:
            Punctuated text; the input unchanged if no words were extracted.
            Words beyond max_len are dropped from the output.
        """
        words, _ = extract_words_and_labels(text, preprocess=True)

        if not words:
            return text

        # Map words to vocabulary ids (<UNK> for out-of-vocabulary words).
        indices = [self.vocab.word2idx.get(w, self.vocab.word2idx['<UNK>']) for w in words][:self.max_len]

        input_tensor = torch.tensor([indices], dtype=torch.long).to(self.device)
        lengths = torch.tensor([len(indices)], dtype=torch.long)

        with torch.no_grad():
            logits = self.model(input_tensor, lengths)
            preds = torch.argmax(logits, dim=-1).cpu().numpy()[0]

        # Truncation may leave fewer predictions than words.
        limit = min(len(words), len(preds))
        restored = []
        for word, label in zip(words[:limit], preds[:limit]):
            punct = LABEL_TO_PUNCT.get(label, '')
            restored.append(word + punct)

        return ' '.join(restored)
| |
|