TrafficSignDetector / check_model.py
VietCat's picture
fix bug
fb70692
#!/usr/bin/env python3
"""
Script to verify if the model has been trained with actual weights.
"""
import torch
from ultralytics import YOLO
from huggingface_hub import hf_hub_download
# Fetch the GTSRB checkpoint from the Hugging Face Hub. hf_hub_download hands
# back a local filesystem path, which is what YOLO() is given below.
repo_id = "VietCat/GTSRB-Model"
file_path = "models/GTSRB.pt"
local_model_path = hf_hub_download(filename=file_path, repo_id=repo_id)
print("Model path: " + str(local_model_path) + "\n")

# Wrap the downloaded weights in the ultralytics YOLO model object.
model = YOLO(local_model_path)

# Section banner: a rule, the title, and another rule, each on its own line.
print("\n".join(["=" * 80, "MODEL WEIGHTS ANALYSIS", "=" * 80]))
# Check model layers and weights.
# Every parameter tensor is bucketed with a simple heuristic:
#   ZERO    -> every element is within 1e-6 of zero (torch.allclose vs. zeros)
#   TRAINED -> otherwise, if the tensor's mean is non-zero
#   RANDOM  -> otherwise (non-zero elements, but a mean of exactly 0)
# NOTE(review): a freshly (randomly) initialised tensor almost never has a
# mean of exactly 0, so "TRAINED" here effectively means "not all zeros"; on
# its own it cannot tell trained weights apart from a random initialisation.
print("\nChecking model weights...")
total_params = 0
zero_params = 0
trained_params = 0
# Read-only inspection: no_grad keeps the reductions below from recording
# autograd history.
with torch.no_grad():
    for name, param in model.model.named_parameters():
        param_count = param.numel()
        total_params += param_count
        # Computed once and reused for both classification and the printout.
        param_mean = param.mean().item()
        if torch.allclose(param, torch.zeros_like(param), atol=1e-6):
            zero_params += param_count
            status = "ZERO"
        elif param_mean != 0:
            trained_params += param_count
            status = "TRAINED"
        else:
            status = "RANDOM"
        if param_count > 1000:  # Only print large layers
            print(
                f"{name:50s} | {param_count:10,} params | {status:10s} | "
                f"mean: {param_mean:.6f}, std: {param.std().item():.6f}"
            )
# Summarise the parameter counts gathered by the scan and print a verdict.
print(f"\n{'='*80}")
print(f"Total parameters: {total_params:,}")
if total_params == 0:
    # Without this guard every percentage/ratio below would raise
    # ZeroDivisionError for a model that exposes no parameters.
    print("\n⚠️ WARNING: Model has no parameters - cannot assess training state!")
else:
    print(f"Zero parameters: {zero_params:,} ({100*zero_params/total_params:.1f}%)")
    print(f"Trained parameters: {trained_params:,} ({100*trained_params/total_params:.1f}%)")
    # Check if this looks like trained weights: mostly-zero tensors suggest an
    # empty/untrained checkpoint; a large "TRAINED" share is taken as success.
    if zero_params / total_params > 0.5:
        print("\n⚠️ WARNING: Model has >50% zero parameters - may not be properly trained!")
    elif trained_params / total_params > 0.7:
        print("\n✅ Model appears to be properly trained")
    else:
        print("\n❓ Uncertain - model may need verification")
print("="*80)