# model.py, uploaded by dhrumii (commit 9e816ce, verified)
import os
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm
import warnings
# Silence all Python warnings for a cleaner console log (note: this also
# hides deprecation notices from the libraries used below).
warnings.filterwarnings('ignore')

# Seed both RNGs so splits, augmentation and weight init are repeatable
np.random.seed(42)
torch.manual_seed(42)

# Configuration: hyperparameters and runtime settings shared by the script
CONFIG = dict(
    img_size=224,                 # square side length images are resized to
    batch_size=16,                # reduced batch size (presumably to fit memory)
    num_epochs=30,
    learning_rate=1e-4,
    patience=7,                   # early-stopping patience, in epochs
    device='cuda' if torch.cuda.is_available() else 'cpu',
    num_workers=0,                # 0 = load data in the main process (no multiprocessing)
    model_save_path='best_trash_classifier.pth',
)
print("Using device: " + CONFIG['device'])
# Memory-Efficient Dataset Class
class TrashDatasetLazy(Dataset):
    """Map-style dataset that decodes one image at a time from a source dataset.

    Only the rows listed in ``indices`` are exposed, so a single underlying
    dataset can back the train/val/test splits without materialising images.

    Args:
        dataset: Indexable source whose rows are dicts with ``'image'`` and
            ``'label'`` keys (here the Hugging Face dataset loaded above).
        indices: Sequence of row positions into ``dataset`` for this split.
        transform: Optional callable applied to the RGB ``PIL.Image``.
    """

    def __init__(self, dataset, indices, transform=None):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th row of this split."""
        # ``indices`` comes from ``train_test_split`` and holds NumPy integers;
        # cast to a plain int because not every ``datasets`` version accepts
        # NumPy scalar keys.
        actual_idx = int(self.indices[idx])
        item = self.dataset[actual_idx]
        image = item['image']
        label = item['label']
        # Convert to PIL Image if needed (e.g. array-like image data)
        if not isinstance(image, Image.Image):
            image = Image.fromarray(np.array(image))
        # Normalise grayscale / RGBA / palette images to 3-channel RGB
        if image.mode != 'RGB':
            image = image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label
# Data Augmentation and Normalization
# Both pipelines resize to a CONFIG['img_size'] square and normalise with the
# ImageNet channel statistics used by the pretrained backbone.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_SQUARE = (CONFIG['img_size'], CONFIG['img_size'])

# Training: random flip, rotation, colour jitter and shift for augmentation
train_transform = transforms.Compose([
    transforms.Resize(_SQUARE),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Validation / test: deterministic resize + normalisation only
val_transform = transforms.Compose([
    transforms.Resize(_SQUARE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])
# Load Dataset (full map-style split, not streaming: len(ds) and random
# access by index are used below)
print("\n" + "="*60)
print("LOADING DATASET")
print("="*60)
ds = load_dataset("rootstrap-org/waste-classifier", split="train")
print("Dataset loaded successfully!")
print(f"Total samples: {len(ds)}")
# Get class names (the 'label' feature exposes `.names`, i.e. a ClassLabel)
class_names = ds.features['label'].names
num_classes = len(class_names)
print(f"\nNumber of classes: {num_classes}")
print(f"Classes: {class_names}")
# Extract only labels for splitting. Column access reads just the 'label'
# column; iterating `ds` row by row would also decode every image.
labels = list(ds['label'])
# Check class distribution
unique, counts = np.unique(labels, return_counts=True)
print("\nClass Distribution:")
for cls_idx, count in zip(unique, counts):
    print(f" {class_names[cls_idx]}: {count} samples ({count/len(labels)*100:.2f}%)")
# Split dataset: 70% train, 15% val, 15% test
print("\n" + "=" * 60)
print("SPLITTING DATASET")
print("=" * 60)
indices = np.arange(len(ds))
# Hold out 30% first, then halve the holdout into validation and test.
# Both splits are stratified so each keeps the overall class proportions.
train_idx, temp_idx, y_train, y_temp = train_test_split(
    indices, labels, stratify=labels, test_size=0.3, random_state=42,
)
val_idx, test_idx, y_val, y_test = train_test_split(
    temp_idx, y_temp, stratify=y_temp, test_size=0.5, random_state=42,
)
for split_name, split_idx in (("Train set", train_idx),
                              ("Validation set", val_idx),
                              ("Test set", test_idx)):
    print(f"{split_name}: {len(split_idx)} samples")

# 'balanced' weights are n_samples / (n_classes * class_count), so rarer
# classes contribute more to the loss; computed from the training split only.
class_weights = torch.FloatTensor(
    compute_class_weight(class_weight='balanced', classes=np.unique(y_train), y=y_train)
).to(CONFIG['device'])
print(f"\nClass weights (for imbalance): {class_weights.cpu().numpy()}")
# Create datasets and dataloaders
train_dataset = TrashDatasetLazy(ds, train_idx, transform=train_transform)
val_dataset = TrashDatasetLazy(ds, val_idx, transform=val_transform)
test_dataset = TrashDatasetLazy(ds, test_idx, transform=val_transform)


def _make_loader(dataset, shuffle):
    """Build a DataLoader with the batch size / worker settings from CONFIG."""
    return DataLoader(
        dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=shuffle,
        num_workers=CONFIG['num_workers'],
        # Pinned host memory only speeds up host-to-GPU copies; without CUDA
        # it brings no benefit and PyTorch warns about it.
        pin_memory=CONFIG['device'] == 'cuda',
    )


train_loader = _make_loader(train_dataset, shuffle=True)  # reshuffled every epoch
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)
# Build Model using EfficientNetV2 (pretrained)
print("\n" + "=" * 60)
print("BUILDING MODEL")
print("=" * 60)
model = models.efficientnet_v2_s(weights='IMAGENET1K_V1')

# Freeze all but the last 20 parameter tensors of the pretrained network.
# NOTE(review): that count presumably includes the original classifier's
# Linear weight and bias, which are replaced just below, so fewer than 20
# backbone tensors actually stay trainable; confirm this is intended.
for param in list(model.parameters())[:-20]:
    param.requires_grad_(False)

# Swap the ImageNet head for a small MLP head sized for our classes
num_features = model.classifier[1].in_features
head_layers = [
    nn.Dropout(p=0.3, inplace=True),
    nn.Linear(num_features, 512),
    nn.ReLU(),
    nn.Dropout(p=0.3),
    nn.Linear(512, num_classes),
]
model.classifier = nn.Sequential(*head_layers)
model = model.to(CONFIG['device'])

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Model: EfficientNetV2-S (pretrained)")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
# Loss function with class weights and optimizer.
# Weighted cross-entropy counteracts class imbalance; AdamW applies
# decoupled weight decay.
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.AdamW(model.parameters(), lr=CONFIG['learning_rate'], weight_decay=0.01)
# Halve the LR once validation loss has gone more than 3 epochs without
# improving. `verbose` is intentionally not passed: it is deprecated for LR
# schedulers and rejected by newer PyTorch releases. The training loop
# prints the current LR instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)
# Training and Validation Functions
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; put into training mode here.
        loader: Iterable of ``(images, labels)`` tensor batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)``: the per-batch loss scaled by batch size
        and divided by the number of samples seen, and the accuracy in percent.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    progress = tqdm(loader, desc='Training')
    for batch_x, batch_y in progress:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item() * batch_x.size(0)
        preds = logits.max(1).indices
        n_seen += batch_y.size(0)
        n_correct += (preds == batch_y).sum().item()
        progress.set_postfix({
            'loss': f'{batch_loss.item():.4f}',
            'acc': f'{100.*n_correct/n_seen:.2f}%',
        })

    return loss_sum / n_seen, 100. * n_correct / n_seen
def validate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` over ``loader`` without computing gradients.

    Args:
        model: Network to evaluate; put into eval mode here.
        loader: Iterable of ``(images, labels)`` tensor batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)``: the per-batch loss scaled by batch size
        and divided by the number of samples seen, and the accuracy in percent.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.no_grad():
        progress = tqdm(loader, desc='Validation')
        for batch_x, batch_y in progress:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            logits = model(batch_x)
            batch_loss = criterion(logits, batch_y)
            loss_sum += batch_loss.item() * batch_x.size(0)
            n_seen += batch_y.size(0)
            n_correct += (logits.max(1).indices == batch_y).sum().item()
            progress.set_postfix({
                'loss': f'{batch_loss.item():.4f}',
                'acc': f'{100.*n_correct/n_seen:.2f}%',
            })
    epoch_loss = loss_sum / n_seen
    epoch_acc = 100. * n_correct / n_seen
    return epoch_loss, epoch_acc
# Training Loop
print("\n" + "="*60)
print("TRAINING MODEL")
print("="*60)
history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': []
}
# Start below any attainable accuracy so the first epoch always writes a
# checkpoint. Starting at 0.0 would leave no file for the test stage to load
# if validation accuracy never rose above 0%.
best_val_acc = float('-inf')
patience_counter = 0
for epoch in range(CONFIG['num_epochs']):
    print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}")
    print("-" * 60)
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, CONFIG['device'])
    val_loss, val_acc = validate_epoch(model, val_loader, criterion, CONFIG['device'])
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    print(f"\nTrain Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
    # The LR schedule tracks validation loss; checkpointing and early
    # stopping below track validation accuracy.
    scheduler.step(val_loss)
    # Report the LR so reductions made by the scheduler are visible
    print(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")
    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'class_names': class_names
        }, CONFIG['model_save_path'])
        print(f"✓ Model saved! (Val Acc: {val_acc:.2f}%)")
        patience_counter = 0
    else:
        patience_counter += 1
        print(f"No improvement ({patience_counter}/{CONFIG['patience']})")
    # Early stopping on validation accuracy
    if patience_counter >= CONFIG['patience']:
        print("\nEarly stopping triggered!")
        break
# Plot Training History
print("\n" + "="*60)
print("SAVING TRAINING GRAPHS")
print("="*60)
# 1-based epoch numbers so the x-axis matches the "Epoch N/M" console output
epochs_run = range(1, len(history['train_loss']) + 1)
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
# Each panel: (axis, history-key suffix, y label, title, legend suffix)
panels = (
    (axes[0], 'loss', 'Loss', 'Training and Validation Loss', 'Loss'),
    (axes[1], 'acc', 'Accuracy (%)', 'Training and Validation Accuracy', 'Acc'),
)
for ax, key, ylabel, title, short in panels:
    ax.plot(epochs_run, history[f'train_{key}'], label=f'Train {short}', marker='o')
    ax.plot(epochs_run, history[f'val_{key}'], label=f'Val {short}', marker='s')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig('training_history.png', dpi=300, bbox_inches='tight')
plt.close(fig)  # release the figure; later plots create their own
print("✓ Training graphs saved as 'training_history.png'")
# Load Best Model and Test
print("\n" + "="*60)
print("LOADING BEST MODEL AND TESTING")
print("="*60)
# Restore the weights from the best-validation-accuracy epoch before testing.
# map_location keeps the load working even if the checkpoint was written on
# a different device than the one configured for this run.
checkpoint = torch.load(CONFIG['model_save_path'], map_location=CONFIG['device'])
model.load_state_dict(checkpoint['model_state_dict'])
print(f"✓ Loaded best model from epoch {checkpoint['epoch']+1}")
print(f" Best validation accuracy: {checkpoint['val_acc']:.2f}%")
# Test the model: collect a prediction for every test sample
model.eval()
all_preds = []
all_labels = []
with torch.no_grad():
    for batch_images, batch_labels in tqdm(test_loader, desc='Testing'):
        logits = model(batch_images.to(CONFIG['device']))
        all_preds.extend(logits.max(1).indices.cpu().numpy())
        all_labels.extend(batch_labels.numpy())

# Calculate test accuracy (percentage of exact label matches)
preds_arr = np.array(all_preds)
labels_arr = np.array(all_labels)
test_acc = 100. * np.sum(preds_arr == labels_arr) / len(all_labels)
banner = "=" * 60
print("\n" + banner)
print(f"TEST SET ACCURACY: {test_acc:.2f}%")
print(banner)
# Classification Report
print("\n" + "="*60)
print("CLASSIFICATION REPORT")
print("="*60)
# Pin the label set to every class index so the report and the confusion
# matrix line up with `class_names`, even if some class never appears in the
# test labels or predictions. Without `labels=`, classification_report raises
# when the number of observed classes differs from len(target_names).
class_ids = list(range(num_classes))
print(classification_report(all_labels, all_preds, labels=class_ids,
                            target_names=class_names, digits=4))
# Confusion Matrix (rows: true class, columns: predicted class)
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix - Test Set')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.close()  # release the figure after saving
print("\n✓ Confusion matrix saved as 'confusion_matrix.png'")
# Final summary of the artifacts written by this script
print("\n" + "="*60)
print("TRAINING COMPLETE!")
print("="*60)
print(f"✓ Best model saved: {CONFIG['model_save_path']}")
print("✓ Training history: training_history.png")
print("✓ Confusion matrix: confusion_matrix.png")
print(f"✓ Final test accuracy: {test_acc:.2f}%")