from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
)
import torch
# 1. Load dataset
dataset = load_dataset("rohitsaxena/MovieSum")
# Rename the raw columns to the generic input/target names used below
dataset = dataset.rename_columns({"script": "input_text", "summary": "target_text"})
# 2. Load model and tokenizer
model_checkpoint = "facebook/bart-large-cnn"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
# 3. Preprocessing
def preprocess_function(examples):
    # Tokenize the full scripts as model inputs.
    inputs = tokenizer(
        examples["input_text"],
        max_length=1024,
        padding="max_length",
        truncation=True,
    )
    # Tokenize the summaries as labels (text_target replaces the
    # deprecated as_target_tokenizer context manager).
    labels = tokenizer(
        text_target=examples["target_text"],
        max_length=128,
        padding="max_length",
        truncation=True,
    )
    # Mask the padding token ids with -100 so they are ignored by the loss.
    inputs["labels"] = [
        [(tok if tok != tokenizer.pad_token_id else -100) for tok in label]
        for label in labels["input_ids"]
    ]
    return inputs
tokenized_dataset = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=dataset["train"].column_names,
)
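# Sanity check: every split should now expose input_ids, attention_mask,
# and labels only.
print(tokenized_dataset)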
# 4. Training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./film-script-summarizer",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=3,
    weight_decay=0.01,
    save_total_limit=2,
    push_to_hub=True,
    hub_model_id="BhavyaSamhithaMallineni/FilmScriptSummarizer",
    hub_strategy="every_save",
    logging_dir="./logs",
    logging_steps=50,
    fp16=torch.cuda.is_available(),
)
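# Optional ROUGE evaluation — a minimal sketch, assuming the `evaluate` and
# `rouge_score` packages are installed. To enable it, pass
# compute_metrics=compute_metrics to the Seq2SeqTrainer below and add
# predict_with_generate=True to the training arguments above.
import numpy as np

def compute_metrics(eval_pred):
    import evaluate  # imported lazily so the script still runs without it

    rouge = evaluate.load("rouge")
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Labels use -100 for padding; swap the pad token id back in before decoding.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    return rouge.compute(predictions=decoded_preds, references=decoded_labels)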
# 5. Trainer setup
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)
# 6. Train and push to hub
trainer.train()
trainer.push_to_hub()
tokenizer.push_to_hub("BhavyaSamhithaMallineni/FilmScriptSummarizer")
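# 7. Quick inference check — a minimal sketch, assuming the push above
# succeeded and the fine-tuned model is available under the hub id used in
# the training arguments.
from transformers import pipeline

summarizer = pipeline(
    "summarization",
    model="BhavyaSamhithaMallineni/FilmScriptSummarizer",
    device=0 if torch.cuda.is_available() else -1,
)
sample_script = dataset["test"][0]["input_text"]
print(summarizer(sample_script, max_length=128, truncation=True)[0]["summary_text"])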