|
|
|
|
|
""" |
|
|
Train YOLOv8 on the TypoRef historical document dataset. |
|
|
|
|
|
This script wraps the Ultralytics YOLO API with a simple command-line |
|
|
interface. It allows you to specify the dataset configuration file, |
|
|
model backbone, image size, number of epochs, batch size, project |
|
|
directory, and experiment name. A YAML file of hyper-parameter overrides
can be supplied via --hyp (default: `configs/hyp_augment.yaml`).
|
|
""" |
|
|
import argparse |
|
|
from ultralytics import YOLO |
|
|
|
|
|
|
|
|
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options for a TypoRef training run.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse
            reads ``sys.argv[1:]``, so existing ``parse_args()`` callers
            behave exactly as before.

    Returns:
        An ``argparse.Namespace`` with the attributes ``data``, ``model``,
        ``imgsz``, ``epochs``, ``batch``, ``workers``, ``project``,
        ``name``, ``hyp``, ``patience`` and ``seed``.
    """
    p = argparse.ArgumentParser(description="Train YOLOv8 for TypoRef document detection")
    p.add_argument("--data", type=str, default="configs/ornaments.yaml", help="Path to data config")
    p.add_argument("--model", type=str, default="yolov8s.pt", help="YOLOv8 backbone model")
    p.add_argument("--imgsz", type=int, default=1024, help="Input image size")
    p.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
    p.add_argument("--batch", type=int, default=8, help="Batch size")
    p.add_argument("--workers", type=int, default=8, help="Number of dataloader workers")
    p.add_argument("--project", type=str, default="runs/typoref", help="Project directory")
    p.add_argument("--name", type=str, default="exp", help="Experiment name")
    p.add_argument("--hyp", type=str, default="configs/hyp_augment.yaml", help="Hyper-parameter file")
    p.add_argument("--patience", type=int, default=30, help="Early stopping patience")
    p.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    # Forwarding argv lets tests and other scripts supply arguments directly
    # rather than having to patch sys.argv.
    return p.parse_args(argv)
|
|
|
|
|
|
|
|
def main():
    """Run a single training job configured from the command line.

    Reads the options via ``parse_args``, builds a ``YOLO`` object from the
    ``--model`` value, calls its ``train`` method with the parsed settings
    plus a few fixed options, and prints whatever ``train`` returns.
    """
    opts = parse_args()
    detector = YOLO(opts.model)

    train_kwargs = {
        # Values taken straight from the command line.
        "data": opts.data,
        "imgsz": opts.imgsz,
        "epochs": opts.epochs,
        "batch": opts.batch,
        "workers": opts.workers,
        "project": opts.project,
        "name": opts.name,
        # Options fixed by this script.
        # NOTE(review): some Ultralytics releases warn that RAM caching can
        # make training non-deterministic; confirm this is acceptable
        # together with deterministic=True.
        "cache": True,
        "amp": True,
        "deterministic": True,
        "patience": opts.patience,
        "seed": opts.seed,
        # The --hyp file path is handed to train() as its `cfg` argument.
        "cfg": opts.hyp,
    }

    outcome = detector.train(**train_kwargs)
    print(outcome)
|
|
|
|
|
|
|
|
# Start training only when this file is executed as a script, so importing
# the module does not launch a run.
if __name__ == "__main__":
    main()