# Uploaded by martinbadrous ("Upload 14 files", commit 37aeb10, verified).
#!/usr/bin/env python3
"""
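Train YOLOv8 on the TypoRef historical document dataset.
This script wraps the Ultralytics YOLO API with a simple command-line
interface. It lets you specify the dataset configuration file, the model
weights/backbone, image size, number of epochs, batch size, number of
dataloader workers, project directory, experiment name, early-stopping
patience, and random seed. Additional training hyper-parameters (e.g.
augmentation settings) are read from the YAML file given by --hyp
(default: `configs/hyp_augment.yaml`); its values override the Ultralytics
defaults, while the options this script passes explicitly take precedence.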
"""
import argparse
from ultralytics import YOLO
def parse_args(argv=None):
    """Parse the command-line options for a training run.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before; passing a list makes
            the function usable from tests or other Python code.

    Returns:
        An ``argparse.Namespace`` with the attributes ``data``, ``model``,
        ``imgsz``, ``epochs``, ``batch``, ``workers``, ``project``, ``name``,
        ``hyp``, ``patience`` and ``seed``.
    """
    p = argparse.ArgumentParser(
        description="Train YOLOv8 for TypoRef document detection",
        # Show each option's default value in --help output.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("--data", type=str, default="configs/ornaments.yaml", help="Path to data config")
    p.add_argument("--model", type=str, default="yolov8s.pt", help="YOLOv8 backbone model")
    p.add_argument("--imgsz", type=int, default=1024, help="Input image size")
    p.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
    p.add_argument("--batch", type=int, default=8, help="Batch size")
    p.add_argument("--workers", type=int, default=8, help="Number of dataloader workers")
    p.add_argument("--project", type=str, default="runs/typoref", help="Project directory")
    p.add_argument("--name", type=str, default="exp", help="Experiment name")
    p.add_argument("--hyp", type=str, default="configs/hyp_augment.yaml", help="Hyper-parameter file")
    p.add_argument("--patience", type=int, default=30, help="Early stopping patience")
    p.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    return p.parse_args(argv)
def main():
    """Run one training job configured from the command line.

    Parses the CLI options, loads the requested YOLO weights, calls
    ``train`` with those options plus a few fixed settings, and prints
    whatever ``train`` returns.
    """
    options = parse_args()
    detector = YOLO(options.model)

    # Keyword arguments for YOLO.train, in the same order as before.
    # cache, amp and deterministic are fixed here and not exposed as CLI flags.
    # cfg points Ultralytics at the --hyp YAML file of hyper-parameter overrides.
    train_kwargs = dict(
        data=options.data,
        imgsz=options.imgsz,
        epochs=options.epochs,
        batch=options.batch,
        workers=options.workers,
        project=options.project,
        name=options.name,
        cache=True,
        amp=True,
        deterministic=True,
        patience=options.patience,
        seed=options.seed,
        cfg=options.hyp,
    )
    outcome = detector.train(**train_kwargs)
    print(outcome)


if __name__ == "__main__":
    main()