# Hugging Face Space "malkhuzanie" — inference.py (commit 1b244e4)
"""
Shared Preprocessing Module for Arabic Text Punctuation Prediction
This module contains all shared functionality for data preprocessing,
label extraction, and dataset creation.
Authors: Mohammed Alkhuzanie, Mohammed Fael
Date: December 2025
"""
import os
import re
import gc
import pickle
import zipfile
from collections import Counter, defaultdict
from typing import List, Tuple, Dict, Optional
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import glob
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from huggingface_hub import hf_hub_download
# ============================================================================
# CONSTANTS AND MAPPINGS
# ============================================================================
# Punctuation class labels used as model targets (7-way classification).
PUNCTUATION_MAP = {
    'NO_PUNCT': 0,
    'QUESTION': 1,     # ?
    'COMMA': 2,        # ،
    'COLON': 3,        # :
    'SEMICOLON': 4,    # ؛
    'EXCLAMATION': 5,  # !
    'PERIOD': 6        # .
}
# Label -> character emitted when reconstructing text. NOTE the deliberate
# asymmetry with PUNCTUATION_CHARS below: question marks are canonicalized
# to ASCII '?' internally (label 1) but rendered back as Arabic '؟'.
LABEL_TO_PUNCT = {
    0: '',
    1: '؟',
    2: '،',
    3: ':',
    4: '؛',
    5: '!',
    6: '.'
}
# Trailing character -> label. Only valid AFTER normalize_punctuation() has
# canonicalized each mark (e.g. '؟' -> '?', ',' -> '،', ';' -> '؛').
PUNCTUATION_CHARS = {
    '?': 1,  # Question mark (normalized from ؟)
    '،': 2,  # Arabic comma (normalized from ,)
    ':': 3,  # Colon
    '؛': 4,  # Arabic semicolon (normalized from ;)
    '!': 5,  # Exclamation mark
    '.': 6   # Period (normalized from \u06d4)
}
# Random seed for reproducibility across shuffling/splitting.
SEED = 42
# ============================================================================
# DATA LOADING FUNCTIONS
# ============================================================================
def load_text_files(data_dir: str) -> List[str]:
    """Recursively collect the contents of every non-empty ``.txt`` file.

    Args:
        data_dir: Root directory to walk.

    Returns:
        One string per readable, non-blank ``.txt`` file found under
        ``data_dir``. Unreadable files are reported and skipped.
    """
    collected = []
    for root, _dirs, filenames in os.walk(data_dir):
        for name in filenames:
            if not name.endswith('.txt'):
                continue
            file_path = os.path.join(root, name)
            try:
                with open(file_path, 'r', encoding='utf-8') as handle:
                    body = handle.read()
            except Exception as e:
                print(f"Error reading {file_path}: {e}")
                continue
            # Keep only files with non-whitespace content.
            if body.strip():
                collected.append(body)
    return collected
# ============================================================================
# TEXT PREPROCESSING FUNCTIONS
# ============================================================================
def remove_diacritics(text: str) -> str:
    """Strip Arabic diacritics (tashkeel) from the text.

    Diacritics mark vowel sounds; removing them is standard in Arabic NLP
    to shrink the vocabulary.

    Args:
        text: Input Arabic text.

    Returns:
        The text with every code point in U+0617–U+061A and U+064B–U+0652
        removed.
    """
    # Single-pass deletion table instead of a regex substitution.
    deletions = {cp: None for cp in range(0x0617, 0x061B)}
    deletions.update({cp: None for cp in range(0x064B, 0x0653)})
    return text.translate(deletions)
def normalize_arabic_text(text: str) -> str:
    """Map Arabic character variants to canonical forms.

    Normalizations:
        - أ، إ، آ → ا (alef variants collapse to bare alef)
        - ة → ه  (taa marbuta to haa)

    Args:
        text: Input Arabic text.

    Returns:
        Normalized text.
    """
    # One translate pass replaces the two regex substitutions.
    table = str.maketrans({'أ': 'ا', 'إ': 'ا', 'آ': 'ا', 'ة': 'ه'})
    return text.translate(table)
def normalize_punctuation(text: str) -> str:
    """Canonicalize punctuation marks so each mark has one representation.

    Normalizations:
        - ؟ (U+061F) → ?  (ASCII question mark)
        - ,          → ،  (Arabic comma)
        - ;          → ؛  (Arabic semicolon)
        - ۔ (U+06D4) → .  (ASCII period)
    Periods, colons, and exclamation marks are left as-is.

    Args:
        text: Input text.

    Returns:
        Text with normalized punctuation.
    """
    # The four replacements touch disjoint characters, so a single
    # translate pass is equivalent to the chained str.replace calls.
    table = str.maketrans({'؟': '?', ',': '،', ';': '؛', '\u06d4': '.'})
    return text.translate(table)
def normalize_punctuation_spacing(text: str) -> str:
    """Insert a space after punctuation marks that lack one.

    Fixes formatting like "word1!word2" → "word1! word2" so label
    extraction sees cleanly separated tokens.

    Args:
        text: Input text.

    Returns:
        Text with a space after each target mark.

    Note:
        The marks are processed one at a time, in a fixed order; this
        matters for adjacent marks (e.g. "x!?y") and mirrors the original
        sequence of re.sub passes exactly.
    """
    for mark in ('?', '،', ':', '؛', '!', '.'):
        spacer = re.compile('(' + re.escape(mark) + r')([^ \n\t])')
        text = spacer.sub(r'\1 \2', text)
    return text
def remove_tatweel(text: str) -> str:
    """Delete the Arabic tatweel/kashida character (ـ, U+0640).

    Tatweel only stretches glyphs for justification and carries no
    linguistic meaning.

    Args:
        text: Input Arabic text.

    Returns:
        Text without tatweel.
    """
    return text.replace('\u0640', '')
def remove_punctuation_marks(text: str) -> str:
    """Strip the six target punctuation marks and collapse whitespace.

    Removes: ? ، : ؛ ! .  (the normalized marks this task predicts),
    then squeezes any whitespace runs left behind into single spaces.

    Args:
        text: Input text with punctuation marks.

    Returns:
        Text without punctuation marks.
    """
    # Same six characters as PUNCTUATION_CHARS, deleted in one pass.
    stripped = text.translate(str.maketrans('', '', '?،:؛!.'))
    return ' '.join(stripped.split())
def preprocess_text(text: str) -> str:
    """Run the full normalization pipeline on raw text.

    Order matters: punctuation must be canonicalized before spacing is
    fixed, and all steps run before whitespace is collapsed.

    Args:
        text: Input text.

    Returns:
        Preprocessed text with single spaces between tokens.
    """
    pipeline = (
        remove_diacritics,
        normalize_arabic_text,
        remove_tatweel,
        normalize_punctuation,
        normalize_punctuation_spacing,
    )
    for step in pipeline:
        text = step(text)
    return ' '.join(text.split())
# ============================================================================
# LABEL EXTRACTION FUNCTION (CRITICAL)
# ============================================================================
def extract_words_and_labels(text: str, preprocess: bool = True) -> Tuple[List[str], List[int]]:
    """Convert punctuated text into parallel word/label lists.

    This is the core routine that turns punctuated text into training
    pairs: each word gets the label of the mark that trails it (0 if none).

    Example:
        Input:  "أكل الولد الخبز، وشرب الماء."
        Output: words  = ['أكل', 'الولد', 'الخبز', 'وشرب', 'الماء']
                labels = [0, 0, 2, 0, 6]   # 0=none, 2=comma, 6=period

    Args:
        text: Raw text with punctuation.
        preprocess: Whether to run preprocess_text() first.

    Returns:
        Tuple of (words, labels) where labels are ints in 0..6.
    """
    if preprocess:
        text = preprocess_text(text)
    # Trailing-character -> label (same mapping as PUNCTUATION_CHARS).
    punct_to_label = {'?': 1, '،': 2, ':': 3, '؛': 4, '!': 5, '.': 6}
    words: List[str] = []
    labels: List[int] = []
    for token in text.split():
        if not token.strip():
            continue
        trailing = token[-1]
        if trailing in punct_to_label:
            word, label = token[:-1], punct_to_label[trailing]
        else:
            word, label = token, 0
        # Tokens that were bare punctuation leave no word behind; drop them.
        if word.strip():
            words.append(word.strip())
            labels.append(label)
    return words, labels
def _process_text_batch(batch_texts: List[str]) -> List[Tuple[List[str], List[int]]]:
    """Turn a batch of raw texts into filtered (words, labels) pairs.

    Each text is split into sentences at whitespace following a period;
    only sentences of 3–100 words are kept.

    Args:
        batch_texts: List of raw text strings.

    Returns:
        List of (words, labels) tuples.
    """
    kept = []
    # Split after periods while keeping them attached to their sentence.
    sentence_splitter = re.compile(r'(?<=\.)\s+')
    for raw in batch_texts:
        for sentence in sentence_splitter.split(raw):
            if not sentence.strip():
                continue
            words, labels = extract_words_and_labels(sentence)
            if 3 <= len(words) <= 100:
                kept.append((words, labels))
    return kept
def process_texts_in_batches(texts: List[str],
                             batch_size: int = 10,
                             max_texts: Optional[int] = None,
                             output_dir: str = 'processed_data') -> int:
    """Process texts in batches, pickling each batch to disk as it is built.

    Keeps memory bounded: only one batch of sequences is held at a time,
    and each non-empty batch is written to output_dir/batch_NNNN.pkl.

    Args:
        texts: List of raw texts.
        batch_size: Number of texts processed per batch.
        max_texts: Cap on texts processed (falsy values mean "all").
        output_dir: Directory receiving the batch files.

    Returns:
        Total number of sequences written.
    """
    import os
    import gc
    os.makedirs(output_dir, exist_ok=True)
    # NOTE: falsy max_texts (None or 0) intentionally means "process all".
    subset = texts[:max_texts] if max_texts else texts
    total = 0
    saved = 0
    for start in tqdm(range(0, len(subset), batch_size), desc="Processing batches"):
        sequences = _process_text_batch(subset[start:start + batch_size])
        if sequences:
            target = os.path.join(output_dir, f'batch_{saved:04d}.pkl')
            with open(target, 'wb') as f:
                pickle.dump(sequences, f)
            total += len(sequences)
            saved += 1
        del sequences
        gc.collect()
    print(f"\nProcessed {total:,} sequences in {saved} batches")
    print(f"Batch files saved to {output_dir}/")
    return total
# ============================================================================
# EXPLORATORY DATA ANALYSIS
# ============================================================================
def get_dataset_statistics_incremental(batch_dir: str) -> Dict:
    """Calculate dataset statistics incrementally from batch files.

    Memory-efficient for large datasets — processes one batch at a time.

    Args:
        batch_dir: Directory containing batch_*.pkl files.

    Returns:
        Dictionary with sequence/word counts, sequence-length statistics,
        per-label counts, and the no-punct/punct class-imbalance ratio.
    """
    import glob
    import gc
    batch_files = sorted(glob.glob(os.path.join(batch_dir, 'batch_*.pkl')))
    # Accumulators updated one batch at a time so peak memory stays low.
    sequence_lengths = []
    label_counts = Counter()
    unique_words = set()
    num_sequences = 0
    num_words = 0
    print(f"Calculating statistics from {len(batch_files)} batch files...")
    for batch_file in tqdm(batch_files, desc="Processing batches for stats"):
        with open(batch_file, 'rb') as f:
            batch_sequences = pickle.load(f)
        for words, labels in batch_sequences:
            seq_len = len(words)
            sequence_lengths.append(seq_len)
            label_counts.update(labels)
            unique_words.update(words)
            num_words += seq_len
            num_sequences += 1
        del batch_sequences
        gc.collect()
    # BUG FIX: np.min/np.max raise ValueError on an empty list; return a
    # zeroed block instead of crashing when no sequences were found.
    if sequence_lengths:
        length_stats = {
            'min': int(np.min(sequence_lengths)),
            'max': int(np.max(sequence_lengths)),
            'mean': float(np.mean(sequence_lengths)),
            'median': float(np.median(sequence_lengths)),
            'std': float(np.std(sequence_lengths)),
            'percentile_95': float(np.percentile(sequence_lengths, 95)),
            'percentile_99': float(np.percentile(sequence_lengths, 99))
        }
    else:
        length_stats = {key: 0.0 for key in
                        ('min', 'max', 'mean', 'median', 'std',
                         'percentile_95', 'percentile_99')}
    # BUG FIX: the original divided by the punctuated-label total without a
    # zero check; a corpus with no punctuation raised ZeroDivisionError.
    punct_total = sum(v for k, v in label_counts.items() if k != 0)
    imbalance = float(label_counts[0] / punct_total) if punct_total else float('inf')
    stats = {
        'num_sequences': num_sequences,
        'num_words': num_words,
        'num_unique_words': len(unique_words),
        'sequence_length': length_stats,
        'label_distribution': dict(label_counts),
        'class_imbalance_ratio': imbalance
    }
    # Release the big accumulators before returning.
    del unique_words, sequence_lengths
    gc.collect()
    return stats
def split_batches_to_directories(batch_dir: str,
                                 output_dir: str,
                                 train_ratio: float = 0.8,
                                 val_ratio: float = 0.1,
                                 test_ratio: float = 0.1,
                                 seed: int = SEED) -> Tuple[str, str, str]:
    """
    Split sequences into train/val/test batch directories.

    MEMORY EFFICIENT: routes sequences into split directories without ever
    loading the whole dataset. Produces:
        - output_dir/batches_train/batch_0000.pkl, batch_0001.pkl, ...
        - output_dir/batches_val/batch_0000.pkl, ...
        - output_dir/batches_test/batch_0000.pkl, ...

    The split is done in two passes over the SAME sorted file list: pass 1
    counts sequences so a global shuffled index assignment can be drawn;
    pass 2 replays the files in identical order, assigning the i-th
    sequence seen to whichever split owns global index i. Correctness
    depends on both passes enumerating sequences in the same order.

    Args:
        batch_dir: Directory containing batch_*.pkl files
        output_dir: Base directory for split batch directories
        train_ratio: Ratio for training set
        val_ratio: Ratio for validation set
        test_ratio: Ratio for test set
        seed: Random seed for reproducibility

    Returns:
        Tuple of (train_dir, val_dir, test_dir) - directories containing split batches

    Raises:
        FileNotFoundError: If batch_dir contains no batch_*.pkl files.
        AssertionError: If the three ratios do not sum to 1.0.
    """
    import glob
    import random
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, \
        "Ratios must sum to 1.0"
    # Get all batch files (sorted so both passes see the same order).
    batch_files = sorted(glob.glob(os.path.join(batch_dir, 'batch_*.pkl')))
    if not batch_files:
        raise FileNotFoundError(f"No batch files found in {batch_dir}")
    print("=" * 80)
    print("SPLITTING DATA (Memory-Efficient - No Merging)")
    print("=" * 80)
    print(f"Processing from: {batch_dir}")
    print(f"Saving to: {output_dir}\n")
    # Create split directories
    train_dir = os.path.join(output_dir, 'batches_train')
    val_dir = os.path.join(output_dir, 'batches_val')
    test_dir = os.path.join(output_dir, 'batches_test')
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)
    # Step 1: Count total sequences (loads each batch once, then frees it).
    print("Step 1: Counting total sequences...")
    total_sequences = 0
    for batch_file in tqdm(batch_files, desc="Counting"):
        with open(batch_file, 'rb') as f:
            batch_seqs = pickle.load(f)
        total_sequences += len(batch_seqs)
        del batch_seqs
        gc.collect()
    print(f"✓ Total sequences: {total_sequences:,}\n")
    # Step 2: Shuffle the global index space and carve it into three sets;
    # only the index sets (ints) are held in memory, never the data.
    print("Step 2: Generating split assignments...")
    random.seed(seed)
    indices = list(range(total_sequences))
    random.shuffle(indices)
    # Calculate split points (val_end truncation may push a stray sequence
    # or two into test; ratios are honored to within rounding).
    train_end = int(train_ratio * total_sequences)
    val_end = train_end + int(val_ratio * total_sequences)
    train_indices = set(indices[:train_end])
    val_indices = set(indices[train_end:val_end])
    test_indices = set(indices[val_end:])
    print(f"✓ Train: {len(train_indices):,} sequences ({len(train_indices)/total_sequences*100:.1f}%)")
    print(f"✓ Val: {len(val_indices):,} sequences ({len(val_indices)/total_sequences*100:.1f}%)")
    print(f"✓ Test: {len(test_indices):,} sequences ({len(test_indices)/total_sequences*100:.1f}%)\n")
    # Step 3: Replay the files in the same order, routing each sequence by
    # its running global index; current_idx must advance for EVERY sequence.
    print("Step 3: Routing sequences to split directories...")
    current_idx = 0
    train_batch_num = 0
    val_batch_num = 0
    test_batch_num = 0
    for batch_file in tqdm(batch_files, desc="Splitting batches"):
        with open(batch_file, 'rb') as f:
            batch_seqs = pickle.load(f)
        # Accumulate this source batch's sequences per destination split.
        train_batch = []
        val_batch = []
        test_batch = []
        for seq in batch_seqs:
            if current_idx in train_indices:
                train_batch.append(seq)
            elif current_idx in val_indices:
                val_batch.append(seq)
            else:  # anything not train/val belongs to test_indices
                test_batch.append(seq)
            current_idx += 1
        # Save split batches (only if non-empty) with per-split numbering.
        if train_batch:
            with open(os.path.join(train_dir, f'batch_{train_batch_num:04d}.pkl'), 'wb') as f:
                pickle.dump(train_batch, f)
            train_batch_num += 1
        if val_batch:
            with open(os.path.join(val_dir, f'batch_{val_batch_num:04d}.pkl'), 'wb') as f:
                pickle.dump(val_batch, f)
            val_batch_num += 1
        if test_batch:
            with open(os.path.join(test_dir, f'batch_{test_batch_num:04d}.pkl'), 'wb') as f:
                pickle.dump(test_batch, f)
            test_batch_num += 1
        # Clear memory before the next source batch.
        del batch_seqs, train_batch, val_batch, test_batch
        gc.collect()
    print(f"\n✓ Split complete!")
    print(f" Train: {train_batch_num} batch files in {train_dir}")
    print(f" Val: {val_batch_num} batch files in {val_dir}")
    print(f" Test: {test_batch_num} batch files in {test_dir}")
    print("\n" + "=" * 80)
    print("✓ DATA SPLIT COMPLETE!")
    print("=" * 80)
    return train_dir, val_dir, test_dir
def load_sequences_from_batches(batch_dir: str) -> List[Tuple[List[str], List[int]]]:
    """Load every sequence from a batch directory into memory.

    Intended for small datasets or the val/test splits; for large data use
    iterate_sequences_from_batches instead.

    Args:
        batch_dir: Directory containing batch_*.pkl files.

    Returns:
        List of (words, labels) tuples.
    """
    import glob
    paths = sorted(glob.glob(os.path.join(batch_dir, 'batch_*.pkl')))
    sequences: List[Tuple[List[str], List[int]]] = []
    for path in tqdm(paths, desc=f"Loading {os.path.basename(batch_dir)}"):
        with open(path, 'rb') as fh:
            sequences.extend(pickle.load(fh))
    print(f"Loaded {len(sequences):,} sequences from {batch_dir}")
    return sequences
def iterate_sequences_from_batches(batch_dir: str):
    """Lazily yield sequences from a batch directory, one at a time.

    Holds only one batch file's worth of sequences in memory at once.

    Args:
        batch_dir: Directory containing batch_*.pkl files.

    Yields:
        (words, labels) tuples in sorted-file order.
    """
    import glob
    for path in sorted(glob.glob(os.path.join(batch_dir, 'batch_*.pkl'))):
        with open(path, 'rb') as fh:
            loaded = pickle.load(fh)
        yield from loaded
        # Drop the batch before loading the next one.
        del loaded
        gc.collect()
def count_sequences_in_batches(batch_dir: str) -> int:
    """Count sequences across all batch files without keeping them around.

    Args:
        batch_dir: Directory containing batch_*.pkl files.

    Returns:
        Total number of sequences.
    """
    import glob
    total = 0
    for path in sorted(glob.glob(os.path.join(batch_dir, 'batch_*.pkl'))):
        with open(path, 'rb') as fh:
            loaded = pickle.load(fh)
        total += len(loaded)
        # Free each batch immediately; only the running count is kept.
        del loaded
        gc.collect()
    return total
def print_dataset_statistics(stats: Dict):
    """Pretty-print the statistics dict produced by the stats functions.

    Args:
        stats: Statistics dictionary (num_sequences, num_words,
            num_unique_words, sequence_length, label_distribution,
            class_imbalance_ratio).
    """
    banner = "=" * 80
    print(banner)
    print("DATASET STATISTICS")
    print(banner)
    print(f"Total sequences: {stats['num_sequences']:,}")
    print(f"Total words: {stats['num_words']:,}")
    print(f"Unique words: {stats['num_unique_words']:,}")
    print(f"\nSequence length statistics:")
    for key, value in stats['sequence_length'].items():
        print(f" {key}: {value:.2f}")
    print("\n" + banner)
    print("LABEL DISTRIBUTION")
    print(banner)
    for label in sorted(stats['label_distribution']):
        count = stats['label_distribution'][label]
        pct = (count / stats['num_words']) * 100
        punct = LABEL_TO_PUNCT[label]
        # Reverse-lookup the symbolic name for this numeric label.
        punct_name = next(k for k, v in PUNCTUATION_MAP.items() if v == label)
        print(f"{label}: {punct_name:15} '{punct}' - {count:7,} ({pct:5.2f}%)")
    print(f"\nClass imbalance ratio (no_punct/punct): {stats['class_imbalance_ratio']:.2f}")
# ============================================================================
# VOCABULARY CLASS
# ============================================================================
class Vocabulary:
    """Word ↔ index vocabulary with special tokens and frequency filtering.

    Handles:
        - Building the vocabulary from in-memory sequences or batch files
        - word2idx / idx2word mappings
        - Special tokens (<PAD> = 0, <UNK> = 1)
        - Minimum-frequency filtering and vocabulary-size capping
        - Pickle persistence
    """

    def __init__(self, max_vocab_size: Optional[int] = None, min_freq: int = 2):
        """Initialize an empty vocabulary.

        Args:
            max_vocab_size: Maximum vocabulary size including the two
                special tokens (None = unlimited).
            min_freq: Minimum word frequency required for inclusion.
        """
        self.word2idx = {'<PAD>': 0, '<UNK>': 1}
        self.idx2word = {0: '<PAD>', 1: '<UNK>'}
        self.word_freq = Counter()
        self.max_vocab_size = max_vocab_size
        self.min_freq = min_freq

    def _finalize(self):
        """Rebuild word2idx/idx2word from the accumulated frequency counts.

        Applies min_freq filtering, frequency-descending ordering, and the
        max_vocab_size cap (two slots reserved for <PAD> and <UNK>).
        Idempotent: mappings are reset before being rebuilt, so calling it
        again after more counting produces a consistent vocabulary.
        """
        kept = [word for word, freq in self.word_freq.items()
                if freq >= self.min_freq]
        # Most frequent first; ties keep first-seen order (stable sort).
        kept.sort(key=lambda w: self.word_freq[w], reverse=True)
        if self.max_vocab_size:
            kept = kept[:self.max_vocab_size - 2]  # -2 for PAD and UNK
        self.word2idx = {'<PAD>': 0, '<UNK>': 1}
        self.idx2word = {0: '<PAD>', 1: '<UNK>'}
        for word in kept:
            idx = len(self.word2idx)
            self.word2idx[word] = idx
            self.idx2word[idx] = word
        print(f"Vocabulary built with {len(self)} words")
        print(f"  Total unique words in corpus: {len(self.word_freq)}")
        print(f"  Words kept (freq >= {self.min_freq}): {len(kept)}")
        # Guard: coverage is undefined for an empty corpus.
        if self.word_freq:
            print(f"  Coverage: {len(self) / len(self.word_freq) * 100:.2f}%")

    def build_vocab(self, sequences: List[Tuple[List[str], List[int]]]):
        """Build the vocabulary from in-memory sequences.

        BUG FIX: the original only counted word frequencies and never built
        the word2idx/idx2word mappings — the finalization logic lived only
        inside build_vocab_incremental, so this method silently produced a
        vocabulary containing just <PAD> and <UNK>. It now finalizes too.

        Args:
            sequences: List of (words, labels) tuples.
        """
        for words, _ in sequences:
            self.word_freq.update(words)
        self._finalize()

    def build_vocab_incremental(self, batch_dir: str):
        """Build the vocabulary from batch files, one file at a time.

        Memory-efficient for large datasets.

        Args:
            batch_dir: Directory containing batch_*.pkl files.
        """
        import glob
        import gc
        batch_files = sorted(glob.glob(os.path.join(batch_dir, 'batch_*.pkl')))
        print(f"Building vocabulary from {len(batch_files)} batch files...")
        for batch_file in tqdm(batch_files, desc="Processing batches"):
            with open(batch_file, 'rb') as f:
                batch_sequences = pickle.load(f)
            for words, _ in batch_sequences:
                self.word_freq.update(words)
            del batch_sequences
            gc.collect()
        self._finalize()

    def encode(self, words: List[str]) -> List[int]:
        """Map words to indices; unknown words map to <UNK> (1).

        Args:
            words: List of words.

        Returns:
            List of indices.
        """
        return [self.word2idx.get(word, self.word2idx['<UNK>']) for word in words]

    def decode(self, indices: List[int]) -> List[str]:
        """Map indices back to words; unknown indices map to '<UNK>'.

        Args:
            indices: List of indices.

        Returns:
            List of words.
        """
        return [self.idx2word.get(idx, '<UNK>') for idx in indices]

    def __len__(self):
        # Vocabulary size including the two special tokens.
        return len(self.word2idx)

    def save(self, path: str):
        """Pickle this vocabulary to `path`."""
        with open(path, 'wb') as f:
            pickle.dump(self, f)
        print(f"Vocabulary saved to {path}")

    @staticmethod
    def load(path: str) -> 'Vocabulary':
        """Load a pickled vocabulary from `path`."""
        with open(path, 'rb') as f:
            vocab = pickle.load(f)
        print(f"Vocabulary loaded from {path}")
        return vocab
# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================
def reconstruct_text_with_punctuation(words: List[str], labels: List[int]) -> str:
    """Rebuild punctuated text from parallel word/label lists.

    BUG FIX: the original appended each punctuation mark as its own
    space-joined token, producing "word ." / "word ،" with a stray space
    before every mark. Marks are now attached directly to the preceding
    word, matching how PunctuationRestorer.predict() renders output.

    Args:
        words: List of words.
        labels: List of punctuation labels (parallel to words; extras in
            either list are ignored, as with zip).

    Returns:
        Space-joined text with each mark attached to its word.
    """
    pieces = []
    for word, label in zip(words, labels):
        pieces.append(word + LABEL_TO_PUNCT.get(label, ''))
    return ' '.join(pieces)
def save_processed_data(sequences: List[Tuple[List[str], List[int]]],
                        filepath: str):
    """Pickle processed (words, labels) sequences to disk.

    Args:
        sequences: List of (words, labels) tuples.
        filepath: Destination file path.
    """
    with open(filepath, 'wb') as sink:
        pickle.dump(sequences, sink)
    print(f"Saved {len(sequences)} sequences to {filepath}")
def load_processed_data(filepath: str) -> List[Tuple[List[str], List[int]]]:
    """Load pickled (words, labels) sequences from disk.

    Args:
        filepath: Path of the pickle file to read.

    Returns:
        List of (words, labels) tuples.
    """
    with open(filepath, 'rb') as source:
        loaded = pickle.load(source)
    print(f"Loaded {len(loaded)} sequences from {filepath}")
    return loaded
class BiLSTM(nn.Module):
    """Bidirectional LSTM tagger emitting one punctuation logit vector per token."""

    def __init__(self, vocab_size, embedding_dim=300, hidden_dim=256,
                 num_layers=2, num_classes=7, dropout=0.3, pretrained_embeddings=None):
        super(BiLSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        # Word embeddings; index 0 is the padding token.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        if pretrained_embeddings is not None:
            self.embedding.weight.data.copy_(pretrained_embeddings)
        # Inter-layer LSTM dropout only applies when the stack is deeper than 1.
        self.lstm = nn.LSTM(
            embedding_dim, hidden_dim, num_layers=num_layers,
            batch_first=True, bidirectional=True,
            dropout=dropout if num_layers > 1 else 0
        )
        self.dropout = nn.Dropout(dropout)
        # Forward + backward states are concatenated -> hidden_dim * 2 features.
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, word_indices, lengths):
        """Return per-timestep logits, shape (batch, max_len, num_classes)."""
        dropped = self.dropout(self.embedding(word_indices))
        # Pack so padded positions are skipped by the LSTM.
        packed_in = pack_padded_sequence(
            dropped, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.lstm(packed_in)
        unpacked, _ = pad_packed_sequence(packed_out, batch_first=True)
        return self.fc(self.dropout(unpacked))
# Leftover notebook-style status message; executes as a side effect on import.
print("✓ Model architecture defined")
class PunctuationRestorer:
    """Restore punctuation in Arabic text with a BiLSTM from the HF Hub.

    Downloads a pickled Vocabulary and a trained checkpoint from the given
    repository, rebuilds the BiLSTM, and exposes `predict` for inference.
    """

    def __init__(self, repo_id="malkhuzanie/arabic-punctuation-checkpoints",
                 model_file="approach1_best_model_256_wf_fc_loss_smoth_none_v1.pth",
                 vocab_file="vocabulary.pkl"):
        """Download artifacts and assemble the model on CPU or CUDA.

        Args:
            repo_id: Hugging Face Hub repository id.
            model_file: Checkpoint filename inside the repo.
            vocab_file: Pickled Vocabulary filename inside the repo.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print("Downloading vocabulary and model...")
        vocab_path = hf_hub_download(repo_id=repo_id, filename=vocab_file)
        model_path = hf_hub_download(repo_id=repo_id, filename=model_file)

        # The vocabulary may have been pickled under a different module
        # namespace; remap any 'Vocabulary' reference onto this module's class.
        class CustomUnpickler(pickle.Unpickler):
            def find_class(self, module, name):
                if name == 'Vocabulary':
                    return Vocabulary
                return super().find_class(module, name)

        with open(vocab_path, 'rb') as f:
            try:
                self.vocab = pickle.load(f)
            except Exception:
                # BUG FIX: was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit. Rewind and retry with the
                # namespace-remapping unpickler.
                f.seek(0)
                self.vocab = CustomUnpickler(f).load()

        # NOTE(security): both pickle.load and weights_only=False can execute
        # arbitrary code from the downloaded files — only use trusted repos.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        self.model = BiLSTM(
            vocab_size=len(self.vocab.word2idx),
            embedding_dim=300, hidden_dim=256, num_layers=2, num_classes=7
        ).to(self.device)
        # Accept both raw state dicts and full training checkpoints.
        state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
        self.model.load_state_dict(state_dict)
        self.model.eval()
        # Inputs are truncated to this many words at inference time.
        self.max_len = 256

    def predict(self, text):
        """Return `text` with predicted punctuation attached to each word.

        Args:
            text: Raw (possibly already punctuated) Arabic text.

        Returns:
            Preprocessed text with one predicted mark appended per word.
            The original string is returned unchanged if no words survive
            preprocessing. Input beyond `max_len` words is truncated.
        """
        # Preprocessing also strips any punctuation already in the input.
        words, _ = extract_words_and_labels(text, preprocess=True)
        if not words:
            return text
        # Map words to indices, truncating to the model's max length.
        indices = [self.vocab.word2idx.get(w, self.vocab.word2idx['<UNK>']) for w in words][:self.max_len]
        input_tensor = torch.tensor([indices], dtype=torch.long).to(self.device)
        # Lengths stay on CPU: pack_padded_sequence requires CPU lengths.
        lengths = torch.tensor([len(indices)], dtype=torch.long)
        with torch.no_grad():
            logits = self.model(input_tensor, lengths)
            preds = torch.argmax(logits, dim=-1).cpu().numpy()[0]
        result = []
        # Ensure we don't go out of bounds if words were truncated by max_len.
        limit = min(len(words), len(preds))
        for word, label in zip(words[:limit], preds[:limit]):
            punct = LABEL_TO_PUNCT.get(label, '')
            result.append(word + punct)
        return ' '.join(result)