import os
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm
import warnings

warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
# Configuration
CONFIG = {
    'img_size': 224,
    'batch_size': 16,       # Reduced batch size
    'num_epochs': 30,
    'learning_rate': 0.0001,
    'patience': 7,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'num_workers': 0,       # Set to 0 to avoid multiprocessing issues
    'model_save_path': 'best_trash_classifier.pth',
}

print(f"Using device: {CONFIG['device']}")
# Memory-efficient dataset class: stores only indices and decodes images lazily
class TrashDatasetLazy(Dataset):
    def __init__(self, dataset, indices, transform=None):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # Cast to a plain int so the Hugging Face dataset accepts the index
        actual_idx = int(self.indices[idx])
        item = self.dataset[actual_idx]
        image = item['image']
        label = item['label']

        # Convert to PIL Image if needed
        if not isinstance(image, Image.Image):
            image = Image.fromarray(np.array(image))

        # Convert to RGB if grayscale
        if image.mode != 'RGB':
            image = image.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label
# Data augmentation and normalization (ImageNet statistics)
train_transform = transforms.Compose([
    transforms.Resize((CONFIG['img_size'], CONFIG['img_size'])),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((CONFIG['img_size'], CONFIG['img_size'])),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Load dataset (fully downloaded once; samples are decoded lazily per item)
print("\n" + "="*60)
print("LOADING DATASET")
print("="*60)

ds = load_dataset("rootstrap-org/waste-classifier", split="train")
print("Dataset loaded successfully!")
print(f"Total samples: {len(ds)}")

# Get class names
class_names = ds.features['label'].names
num_classes = len(class_names)
print(f"\nNumber of classes: {num_classes}")
print(f"Classes: {class_names}")
# Extract only the labels for splitting (column access avoids decoding the images)
labels = ds['label']

# Check class distribution
unique, counts = np.unique(labels, return_counts=True)
print("\nClass Distribution:")
for cls_idx, count in zip(unique, counts):
    print(f"  {class_names[cls_idx]}: {count} samples ({count/len(labels)*100:.2f}%)")
# Split dataset: 70% train, 15% val, 15% test (stratified)
print("\n" + "="*60)
print("SPLITTING DATASET")
print("="*60)

indices = np.arange(len(ds))
train_idx, temp_idx, y_train, y_temp = train_test_split(
    indices, labels, test_size=0.3, random_state=42, stratify=labels
)
val_idx, test_idx, y_val, y_test = train_test_split(
    temp_idx, y_temp, test_size=0.5, random_state=42, stratify=y_temp
)

print(f"Train set: {len(train_idx)} samples")
print(f"Validation set: {len(val_idx)} samples")
print(f"Test set: {len(test_idx)} samples")
# Calculate class weights for handling imbalance
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(y_train),
    y=y_train
)
class_weights = torch.FloatTensor(class_weights).to(CONFIG['device'])
print(f"\nClass weights (for imbalance): {class_weights.cpu().numpy()}")
# Create datasets and dataloaders
train_dataset = TrashDatasetLazy(ds, train_idx, transform=train_transform)
val_dataset = TrashDatasetLazy(ds, val_idx, transform=val_transform)
test_dataset = TrashDatasetLazy(ds, test_idx, transform=val_transform)

train_loader = DataLoader(
    train_dataset, batch_size=CONFIG['batch_size'],
    shuffle=True, num_workers=CONFIG['num_workers'], pin_memory=True
)
val_loader = DataLoader(
    val_dataset, batch_size=CONFIG['batch_size'],
    shuffle=False, num_workers=CONFIG['num_workers'], pin_memory=True
)
test_loader = DataLoader(
    test_dataset, batch_size=CONFIG['batch_size'],
    shuffle=False, num_workers=CONFIG['num_workers'], pin_memory=True
)
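
# Optional sanity check (sketch, not in the original pipeline): pull one batch
# to confirm shapes before training. With the CONFIG above this should print
# images: torch.Size([16, 3, 224, 224]) and labels: torch.Size([16]).
sample_images, sample_labels = next(iter(train_loader))
print(f"Sample batch - images: {sample_images.shape}, labels: {sample_labels.shape}")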
# Build model: EfficientNetV2-S pretrained on ImageNet with a custom classifier head
print("\n" + "="*60)
print("BUILDING MODEL")
print("="*60)

model = models.efficientnet_v2_s(weights='IMAGENET1K_V1')

# Freeze all but the last ~20 parameter tensors of the backbone;
# the replacement classifier head below stays trainable by default
for param in list(model.parameters())[:-20]:
    param.requires_grad = False

# Replace the classifier head for our number of classes
num_features = model.classifier[1].in_features
model.classifier = nn.Sequential(
    nn.Dropout(p=0.3, inplace=True),
    nn.Linear(num_features, 512),
    nn.ReLU(),
    nn.Dropout(p=0.3),
    nn.Linear(512, num_classes)
)
model = model.to(CONFIG['device'])

print("Model: EfficientNetV2-S (pretrained)")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
# Loss function with class weights, AdamW optimizer, and LR reduction on plateau
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.AdamW(model.parameters(), lr=CONFIG['learning_rate'], weight_decay=0.01)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)
# Training and validation functions
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc='Training')
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100.*correct/total:.2f}%'})

    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc
def validate_epoch(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(loader, desc='Validation')
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100.*correct/total:.2f}%'})

    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc
# Training loop with checkpointing and early stopping
print("\n" + "="*60)
print("TRAINING MODEL")
print("="*60)

history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': []
}
best_val_acc = 0.0
patience_counter = 0

for epoch in range(CONFIG['num_epochs']):
    print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}")
    print("-" * 60)

    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, CONFIG['device'])
    val_loss, val_acc = validate_epoch(model, val_loader, criterion, CONFIG['device'])

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(f"\nTrain Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

    scheduler.step(val_loss)

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'class_names': class_names
        }, CONFIG['model_save_path'])
        print(f"✓ Model saved! (Val Acc: {val_acc:.2f}%)")
        patience_counter = 0
    else:
        patience_counter += 1
        print(f"No improvement ({patience_counter}/{CONFIG['patience']})")

    if patience_counter >= CONFIG['patience']:
        print("\nEarly stopping triggered!")
        break
# Plot training history
print("\n" + "="*60)
print("SAVING TRAINING GRAPHS")
print("="*60)

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Loss plot
axes[0].plot(history['train_loss'], label='Train Loss', marker='o')
axes[0].plot(history['val_loss'], label='Val Loss', marker='s')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy plot
axes[1].plot(history['train_acc'], label='Train Acc', marker='o')
axes[1].plot(history['val_acc'], label='Val Acc', marker='s')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Training and Validation Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_history.png', dpi=300, bbox_inches='tight')
print("✓ Training graphs saved as 'training_history.png'")
# Load best model and evaluate on the test set
print("\n" + "="*60)
print("LOADING BEST MODEL AND TESTING")
print("="*60)

checkpoint = torch.load(CONFIG['model_save_path'], map_location=CONFIG['device'])
model.load_state_dict(checkpoint['model_state_dict'])
print(f"✓ Loaded best model from epoch {checkpoint['epoch']+1}")
print(f"  Best validation accuracy: {checkpoint['val_acc']:.2f}%")
# Test the model
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc='Testing'):
        images = images.to(CONFIG['device'])
        outputs = model(images)
        _, predicted = outputs.max(1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.numpy())

# Calculate test accuracy
test_acc = 100. * np.sum(np.array(all_preds) == np.array(all_labels)) / len(all_labels)
print(f"\n{'='*60}")
print(f"TEST SET ACCURACY: {test_acc:.2f}%")
print(f"{'='*60}")
# Classification report
print("\n" + "="*60)
print("CLASSIFICATION REPORT")
print("="*60)
print(classification_report(all_labels, all_preds, target_names=class_names, digits=4))

# Confusion matrix
cm = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix - Test Set')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
print("\n✓ Confusion matrix saved as 'confusion_matrix.png'")

print("\n" + "="*60)
print("TRAINING COMPLETE!")
print("="*60)
print(f"✓ Best model saved: {CONFIG['model_save_path']}")
print("✓ Training history: training_history.png")
print("✓ Confusion matrix: confusion_matrix.png")
print(f"✓ Final test accuracy: {test_acc:.2f}%")