ReCasePunct 1 Flash
We introduce ReCasePunct 1 Flash, our first model capable of punctuation and casing restoration!
Given lowercase, non-punctuated English text of almost any length (I have not found a practical limit in testing), this model predicts punctuation and casing — and the results are impressive!
It also runs very fast on CPU!
Use cases could be for ASR tasks (some models give text without casing and punctuation, like on auto-generated subtitles for YouTube videos from 2023/2024/2025)
Limitations
This model was trained ONLY on English Tatoeba data and doesn't do well for other languages.
Also, its output is sometimes imperfect (especially with proper nouns like "Minecraft").
We might train a multi-lingual and better ReCasePunct model next!
How To Run It
Code by Gemini 2.5 Flash:
from transformers import AutoTokenizer, AlbertConfig, AlbertModel
import torch
import torch.nn as nn
import re
import numpy as np
from safetensors.torch import load_file # Import safe_load for safetensors
from huggingface_hub import hf_hub_download # Import hf_hub_download
# Redefine the model class (must be the same as during training)
class AlbertForPunctuationAndCasing(nn.Module):
    """ALBERT encoder with two parallel token-classification heads.

    One head predicts a punctuation label per token, the other a casing
    label per token. Must match the class used during training so the
    saved state dict loads cleanly.
    """

    def __init__(self, config):
        super().__init__()
        self.num_punctuation_labels = config.num_punctuation_labels
        self.num_casing_labels = config.num_casing_labels
        # Build the encoder straight from the config; the config should
        # reflect the true albert-large-v2 architecture (hidden_size etc.).
        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.punctuation_classifier = nn.Linear(config.hidden_size, self.num_punctuation_labels)
        self.casing_classifier = nn.Linear(config.hidden_size, self.num_casing_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        casing_labels=None,
        punctuation_labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the encoder and both heads; compute a joint loss when labels are given.

        Returns a dict with ``loss`` / ``punctuation_logits`` / ``casing_logits``
        (plus ``hidden_states`` / ``attentions`` when requested), or the
        legacy tuple form when ``return_dict=False``.
        """
        if return_dict is None:
            return_dict = True

        encoder_out = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden = self.dropout(encoder_out[0])
        punctuation_logits = self.punctuation_classifier(hidden)
        casing_logits = self.casing_classifier(hidden)

        # Joint loss: sum of the two per-token cross-entropies, ignoring
        # positions labelled -100 (padding / special tokens).
        loss = None
        if casing_labels is not None and punctuation_labels is not None:
            criterion = nn.CrossEntropyLoss(ignore_index=-100)
            punct_loss = criterion(
                punctuation_logits.view(-1, self.num_punctuation_labels),
                punctuation_labels.view(-1),
            )
            case_loss = criterion(
                casing_logits.view(-1, self.num_casing_labels),
                casing_labels.view(-1),
            )
            loss = punct_loss + case_loss

        if not return_dict:
            tail = (punctuation_logits, casing_logits) + encoder_out[2:]
            return tail if loss is None else (loss,) + tail

        result = {
            "loss": loss,
            "punctuation_logits": punctuation_logits,
            "casing_logits": casing_logits,
        }
        if encoder_out.hidden_states is not None:
            result["hidden_states"] = encoder_out.hidden_states
        if encoder_out.attentions is not None:
            result["attentions"] = encoder_out.attentions
        return result
# --- Configuration and Mappings (must be the same as during training) ---
# Punctuation labels: 'O' means "no punctuation after this token".
punctuation_labels = ['O', '.', ',', '?', '!', ';', ':', '-', '"', '(', ')', '/', '\\']
punctuation_to_id = {mark: idx for idx, mark in enumerate(punctuation_labels)}
id_to_punctuation = dict(enumerate(punctuation_labels))
# Casing labels: 'O' = keep lowercase, 'CAP' = capitalize, 'UPPER' = all caps.
casing_labels = ['O', 'CAP', 'UPPER']
casing_to_id = {case: idx for idx, case in enumerate(casing_labels)}
id_to_casing = dict(enumerate(casing_labels))
# Base checkpoint name: only used to fetch the correct ALBERT architecture config.
model_checkpoint = 'albert-large-v2'
# Define the Hugging Face repository ID that hosts the fine-tuned weights.
hf_repo_id = "MihaiPopa-1/ReCasePunct-1-Flash"
# Load tokenizer from Hugging Face Hub (requires network access on first run).
tokenizer = AutoTokenizer.from_pretrained(hf_repo_id)
# --- CORRECTED MODEL CONFIG LOADING ---
# 1. Load the base ALBERT Large v2 configuration to get correct architecture defaults (like hidden_size)
config = AlbertConfig.from_pretrained(model_checkpoint)
# 2. Set the custom labels on this correctly sized config
#    (AlbertForPunctuationAndCasing reads these two attributes in __init__).
config.num_punctuation_labels = len(punctuation_labels)
config.num_casing_labels = len(casing_labels)
# Instantiate the custom model with the corrected config
model = AlbertForPunctuationAndCasing(config)
# Download the model.safetensors file from the Hub
safetensors_path = hf_hub_download(repo_id=hf_repo_id, filename="model.safetensors")
# Load the full state dictionary into the custom model
model.load_state_dict(load_file(safetensors_path, device='cpu'))
# Switch to inference mode (disables dropout).
model.eval()
def clean_text(text):
    """Lowercase *text* and strip punctuation to mimic the model's training input."""
    lowered = text.lower()
    # Remove the punctuation characters the model is trained to restore.
    no_punct = re.sub(r'[\.,\?!\-;:"\(\)\[\]\{\}\/\\]', '', lowered)
    # Collapse any whitespace runs to single spaces and trim the ends.
    return re.sub(r'\s+', ' ', no_punct).strip()
def predict_punctuation_and_casing(text, model, tokenizer, id_to_punctuation, id_to_casing):
    """Restore punctuation and casing in *text* using the two-headed model.

    Args:
        text: Arbitrary English text; it is lowercased and stripped of
            punctuation first so inference matches the training preprocessing.
        model: Trained ``AlbertForPunctuationAndCasing`` in eval mode.
        tokenizer: Matching fast tokenizer (must support offset mapping).
        id_to_punctuation: Label-id -> punctuation-string map from training.
        id_to_casing: Label-id -> casing-string map from training.

    Returns:
        The cleaned text with predicted casing applied per word and predicted
        punctuation appended after each word.
    """
    # Clean the input text the same way the training data was prepared.
    cleaned_text_input = clean_text(text)
    words_in_cleaned_text = cleaned_text_input.split()

    tokenized_input = tokenizer(
        cleaned_text_input,
        return_offsets_mapping=True,
        truncation=True,
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    )

    with torch.no_grad():
        outputs = model(
            input_ids=tokenized_input['input_ids'],
            attention_mask=tokenized_input['attention_mask'],
        )

    punctuation_logits = outputs['punctuation_logits'].squeeze(0).numpy()
    casing_logits = outputs['casing_logits'].squeeze(0).numpy()
    punctuation_predictions = np.argmax(punctuation_logits, axis=-1)
    casing_predictions = np.argmax(casing_logits, axis=-1)

    # BUGFIX/perf: hoist these out of the loop. The original recomputed
    # offset_mapping.squeeze(0).numpy() twice per token, making the loop
    # accidentally quadratic in sequence length.
    offsets = tokenized_input['offset_mapping'].squeeze(0).numpy()
    input_ids = tokenized_input['input_ids'].squeeze(0)
    special_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id}

    reconstructed_text_parts = []
    current_word_idx = 0

    for token_idx, (token_start, token_end) in enumerate(offsets):
        # Special tokens ([CLS], [SEP], padding) map to the empty span (0, 0).
        if token_start == 0 and token_end == 0:
            continue
        token_text = cleaned_text_input[token_start:token_end]
        # Only the first subword of each word triggers reconstruction; the
        # startswith check is a heuristic alignment of subwords to words and
        # might need refinement for complex tokenizations.
        if current_word_idx < len(words_in_cleaned_text) and words_in_cleaned_text[current_word_idx].startswith(token_text):
            word = words_in_cleaned_text[current_word_idx]
            # Apply the predicted casing to the whole word.
            casing_pred_label = id_to_casing[casing_predictions[token_idx]]
            if casing_pred_label == 'CAP':
                word = word.capitalize()
            elif casing_pred_label == 'UPPER':
                word = word.upper()
            # Append punctuation only at the last subword token of a word:
            # either the next token starts a new span / is a special token,
            # or this is the final token of the sequence.
            is_last_token_of_word = True
            if token_idx + 1 < len(offsets):
                next_token_start = offsets[token_idx + 1][0]
                is_last_token_of_word = (
                    next_token_start >= token_end
                    or input_ids[token_idx + 1].item() in special_ids
                )
            if is_last_token_of_word:
                punctuation_pred_label = id_to_punctuation[punctuation_predictions[token_idx]]
                if punctuation_pred_label != 'O':
                    word += punctuation_pred_label
            reconstructed_text_parts.append(word)
            current_word_idx += 1

    # Re-attach punctuation marks that space-joining would otherwise detach.
    return ' '.join(reconstructed_text_parts).replace(' .', '.').replace(' ,', ',').replace(' ?', '?').replace(' !', '!').replace(' ;', ';').replace(' :', ':').replace(' -', '-').replace(' "', '"').replace('( ', '(').replace(' )', ')').replace(' /', '/').replace(' \\', '\\')
# --- Test Case for a single sentence ---
# Quick smoke test: run one sentence through the full pipeline and show
# the raw input next to the restored output.
single_sample_sentence = "replace me by whatever sentence you like"
print(f"Original: {single_sample_sentence}")
print(f"Predicted: {predict_punctuation_and_casing(single_sample_sentence, model, tokenizer, id_to_punctuation, id_to_casing)}\n")
Should give: Replace me by whatever sentence you like.
Examples
| Original Sentence | Predicted Sentence |
|---|---|
| this is a test of punctuation prediction for english how are you doing today | This is a test of punctuation prediction for English. How are you doing today? |
| i love running this on t4 gpu and so for this goal we might make a better and more accurate model in the future | I love running this on T4 GPU and so, for this goal, we might make a better and more accurate model in the future. |
| so imagine this we live in a world with complex models yet this model does punctuation and casing prediction for english and it's very small at just only 18 million parameters | So, imagine this, we live in a world with complex models. Yet this model does punctuation and casing prediction for English, and it's very small at just only 18 million parameters. |
Evaluation Results
| Epoch | Training Loss | Validation Loss | Punctuation Accuracy | Casing Accuracy | Overall Accuracy |
|---|---|---|---|---|---|
| 1 | 0.072175 | 0.070485 | 0.642053 (64.21%) | 0.638791 (63.88%) | 0.640422 (64.04%) |
| 2 | 0.052846 | 0.063811 | 0.642343 (64.23%) | 0.640475 (64.05%) | 0.641409 (64.14%) |
| 3 | 0.031407 | 0.062892 | 0.640457 (64.05%) | 0.640098 (64.01%) | 0.640278 (64.03%) |
Model tree for MihaiPopa-1/ReCasePunct-1-Flash
Base model
albert/albert-large-v2