"""Gradio demo: binary tomato / not-tomato image classifier.

The app downloads a fine-tuned MobileNetV3-Small state dict from the Hugging
Face Hub at import time, then serves a small UI that shows the prediction,
the preprocessed model input, and a bar chart of the class probabilities.
"""

import io  # NOTE(review): appears unused here; kept in case other code relies on it.
import os

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import models, transforms

# Constants from the model card
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
INPUT_SIZE = 224
MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB

# Dropout probability used when rebuilding the head for the saved checkpoint.
_CHECKPOINT_DROPOUT = 0.4761270681732692


def _build_mobilenet(dropout):
    """Return an untrained MobileNetV3-Small whose head emits two logits.

    Shared by ``TomatoClassifier`` and ``load_model`` so both build exactly
    the same architecture.
    """
    # ``weights=None`` replaces the deprecated ``pretrained=False`` argument.
    net = models.mobilenet_v3_small(weights=None)
    in_features = net.classifier[0].in_features
    net.classifier = nn.Sequential(
        nn.Linear(in_features, 1024),
        nn.Hardswish(),
        nn.Dropout(p=dropout),
        nn.Linear(1024, 2),  # Binary classification
    )
    return net


class TomatoClassifier(nn.Module):
    """MobileNetV3 Small model for binary tomato classification.

    The network is stored under ``self.backbone``. ``load_model`` does not use
    this wrapper: it loads the checkpoint into the bare torchvision model.

    Args:
        dropout: Dropout probability applied inside the classifier head.
    """

    def __init__(self, dropout=0.476):
        super().__init__()
        self.backbone = _build_mobilenet(dropout)

    def forward(self, x):
        """Return the two-class logits for the image batch ``x``."""
        return self.backbone(x)


def load_model():
    """Load the trained model from Hugging Face Hub.

    Returns:
        The MobileNetV3-Small network with the downloaded weights loaded,
        on CPU and switched to eval mode.
    """
    print("Loading model...")

    # Download model weights - the file is named model_state.pt
    model_path = hf_hub_download(
        repo_id="kevinkyi/Homework2_NN",
        filename="model_state.pt",
    )

    # The saved model is the full MobileNetV3 model, not wrapped in
    # TomatoClassifier, so rebuild the bare network with the matching head.
    net = _build_mobilenet(_CHECKPOINT_DROPOUT)

    # The file is a plain state dict, so restrict unpickling to tensors and
    # containers instead of executing arbitrary pickled code from a download.
    state_dict = torch.load(
        model_path,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    net.load_state_dict(state_dict)
    net.eval()

    print("Model loaded successfully!")
    return net


# Load model at startup
model = load_model()

# Define transforms
eval_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD),
])


def denormalize_image(tensor):
    """Convert normalized tensor back to displayable image.

    Args:
        tensor: Image tensor normalized with ``MEAN`` / ``STD``; the channel
            axis must broadcast against a ``(3, 1, 1)`` tensor.

    Returns:
        A tensor of the same shape with values clamped to ``[0, 1]``.
    """
    # Build the statistics on the input's dtype/device so the arithmetic
    # also works for non-default tensors.
    mean = torch.tensor(MEAN, dtype=tensor.dtype, device=tensor.device).view(3, 1, 1)
    std = torch.tensor(STD, dtype=tensor.dtype, device=tensor.device).view(3, 1, 1)
    return torch.clamp(tensor * std + mean, 0, 1)


def validate_image(image):
    """Validate image input.

    Only file paths are size-checked; arrays and other objects that are
    already loaded are accepted as-is.

    Args:
        image: ``None``, a numpy array, a file path string, or another
            image object.

    Returns:
        Tuple ``(is_valid, message)``.
    """
    if image is None:
        return False, "No image provided"

    # Check file size if it's a file path
    if isinstance(image, str):
        try:
            file_size = os.path.getsize(image)
        except OSError as e:
            return False, f"Error validating image: {e}"
        if file_size > MAX_FILE_SIZE:
            return False, (
                f"File too large ({file_size/(1024*1024):.1f}MB). "
                "Maximum size is 10MB."
            )

    return True, "Valid image"


def _to_pil_rgb(image):
    """Coerce a numpy array, file path, or PIL image into an RGB PIL image."""
    if isinstance(image, Image.Image):
        return image.convert("RGB")
    if isinstance(image, np.ndarray):
        return Image.fromarray(image).convert("RGB")
    # File paths (and anything else Image.open accepts) end up here.
    return Image.open(image).convert("RGB")


def _format_result(predicted_index, confidence_value, not_tomato_prob,
                   tomato_prob, confidence_threshold):
    """Build the Markdown summary shown next to the plot."""
    prediction_class = "🍅 TOMATO" if predicted_index == 1 else "❌ NOT TOMATO"

    # Lines are joined without leading indentation so Markdown never
    # mistakes them for an indented code block.
    lines = [
        f"## Prediction: {prediction_class}",
        "",
        f"**Confidence:** {confidence_value:.2%}",
        "",
        "### Class Probabilities:",
        f"- Not Tomato: {not_tomato_prob:.2%}",
        f"- Tomato: {tomato_prob:.2%}",
        "",
        "### Model Details:",
        "- Architecture: MobileNetV3-Small",
        f"- Input Size: {INPUT_SIZE}×{INPUT_SIZE}",
        f"- Confidence Threshold: {confidence_threshold:.2%}",
        "",
    ]
    if confidence_value < confidence_threshold:
        lines.append(
            f"⚠️ **Low Confidence Warning**: Prediction confidence "
            f"({confidence_value:.2%}) is below threshold "
            f"({confidence_threshold:.2%})"
        )
    return "\n".join(lines)


def _build_confidence_figure(not_tomato_prob, tomato_prob, confidence_threshold):
    """Draw the horizontal bar chart of class probabilities."""
    # Imported where it is used, as in the original code.
    from matplotlib.figure import Figure

    # A standalone Figure is not registered with pyplot, so figures do not
    # accumulate across requests in a long-running server.
    fig = Figure(figsize=(8, 4))
    ax = fig.subplots()

    categories = ['Not Tomato', 'Tomato']
    probabilities_list = [not_tomato_prob, tomato_prob]
    colors = ['#ff6b6b', '#51cf66']

    ax.barh(categories, probabilities_list, color=colors, alpha=0.7)
    ax.set_xlim(0, 1)
    ax.set_xlabel('Probability', fontsize=12)
    ax.set_title('Classification Confidence', fontsize=14, fontweight='bold')
    ax.axvline(
        x=confidence_threshold,
        color='gray',
        linestyle='--',
        linewidth=2,
        label=f'Threshold ({confidence_threshold:.0%})',
    )

    # Add percentage labels just past the end of each bar
    for row, prob in enumerate(probabilities_list):
        ax.text(prob + 0.02, row, f'{prob:.1%}', va='center',
                fontsize=11, fontweight='bold')

    ax.legend()
    fig.tight_layout()
    return fig


def predict_tomato(
    image,
    confidence_threshold=0.5,
    show_preprocessing=True
):
    """
    Predict whether an image contains a tomato

    Args:
        image: Input image (PIL Image, numpy array, or file path string)
        confidence_threshold: Confidence below which a low-confidence warning
            is appended to the result text; it does not change the predicted
            class
        show_preprocessing: Whether to show the preprocessed image

    Returns:
        Tuple of (prediction_text, original_image, preprocessed_image,
        confidence_plot). When validation or prediction fails, the first
        element is the error message and the other three are None.
    """
    try:
        # Validate image
        is_valid, message = validate_image(image)
        if not is_valid:
            return message, None, None, None

        pil_image = _to_pil_rgb(image)

        # Store original for display
        original_image = pil_image.copy()

        # Apply preprocessing and add the batch dimension
        preprocessed_tensor = eval_transform(pil_image).unsqueeze(0)

        # Make prediction
        with torch.no_grad():
            outputs = model(preprocessed_tensor)
            probabilities = torch.softmax(outputs, dim=1)
            confidence, predicted = torch.max(probabilities, 1)

        # Class index 0 is "not tomato", index 1 is "tomato"
        not_tomato_prob = probabilities[0][0].item()
        tomato_prob = probabilities[0][1].item()

        result_text = _format_result(
            predicted.item(),
            confidence.item(),
            not_tomato_prob,
            tomato_prob,
            confidence_threshold,
        )
        fig = _build_confidence_figure(
            not_tomato_prob, tomato_prob, confidence_threshold
        )

        # Only build the displayable (H, W, C) array when it will be shown
        preprocessed_display = None
        if show_preprocessing:
            denormalized = denormalize_image(preprocessed_tensor.squeeze(0))
            preprocessed_display = denormalized.permute(1, 2, 0).numpy()

        return result_text, original_image, preprocessed_display, fig

    except Exception as e:
        # UI boundary: report the failure in the result panel rather than
        # letting the request crash.
        error_msg = f"❌ **Error during prediction:**\n\n{e}"
        return error_msg, None, None, None


# Create Gradio interface
with gr.Blocks(title="🍅 Tomato Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🍅 Tomato vs Not-Tomato Classifier

    This application uses a **MobileNetV3-Small** neural network trained with AutoML to classify images as tomato or not-tomato.

    ### How to Use:
    1. Upload an image (PNG/JPG, max 10MB) or use your webcam
    2. Adjust the confidence threshold if needed
    3. Click "Classify Image" to see results
    4. View the original image, preprocessed input, and confidence scores

    **Model Info:** Trained on 30 images with 83% test accuracy | Binary classification (0=not_tomato, 1=tomato)
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📤 Input")
            image_input = gr.Image(
                label="Upload Image",
                type="pil",
                sources=["upload", "webcam", "clipboard"],
                height=300
            )

            with gr.Accordion("⚙️ Advanced Settings", open=False):
                confidence_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.05,
                    label="Confidence Threshold",
                    info="Minimum confidence required for classification"
                )
                show_preprocess = gr.Checkbox(
                    label="Show Preprocessed Image",
                    value=True,
                    info="Display how the model sees the image after preprocessing"
                )

            classify_btn = gr.Button("🔍 Classify Image", variant="primary", size="lg")

        with gr.Column(scale=1):
            gr.Markdown("### 📊 Results")
            result_text = gr.Markdown(label="Prediction")
            confidence_plot = gr.Plot(label="Confidence Scores")

    with gr.Row():
        with gr.Column():
            original_output = gr.Image(label="Original Image", type="pil")
        with gr.Column():
            preprocessed_output = gr.Image(
                label="Preprocessed Image (Model Input)",
                type="numpy"
            )

    # Example images
    gr.Markdown("### 🎯 Try These Examples")
    gr.Markdown("*Click on an example to load it*")

    gr.Examples(
        examples=[
            ["examples/tomato1.jpg", 0.5, True],
            ["examples/not_tomato1.jpg", 0.5, True],
            ["examples/tomato2.jpg", 0.5, True],
        ],
        inputs=[image_input, confidence_slider, show_preprocess],
        fn=predict_tomato
    )

    # Add information section
    with gr.Accordion("ℹ️ Model Information", open=False):
        gr.Markdown("""
        ### Architecture
        - **Base Model:** MobileNetV3-Small (pretrained, fine-tuned)
        - **Dropout:** 0.476
        - **Optimizer:** AdamW
        - **Learning Rate:** 1.186e-05
        - **Input Resolution:** 224×224 pixels

        ### Preprocessing
        - **Normalization:** ImageNet mean/std
        - **Mean:** [0.485, 0.456, 0.406]
        - **Std:** [0.229, 0.224, 0.225]
        - **Transforms:** Resize(256) → CenterCrop(224) → Normalize

        ### Performance
        - **Test Accuracy:** 83%
        - **Test F1 Score:** 0.80
        - **Training Data:** ~30 images (very small dataset)

        ### Limitations
        - Small training dataset may lead to overfitting
        - Performance may degrade on out-of-distribution images
        - Sensitive to lighting and background variations
        - Not suitable for production use (educational project)
        """)

    with gr.Accordion("⚠️ Known Failure Modes", open=False):
        gr.Markdown("""
        This model may struggle with:
        - Cartoon or illustrated tomatoes
        - Extreme angles or unusual perspectives
        - Heavy shadows or overexposure
        - Images with multiple food items
        - Cherry tomatoes or unusual tomato varieties
        - Processed tomato products (sauce, soup, etc.)
        """)

    # Connect the button
    classify_btn.click(
        fn=predict_tomato,
        inputs=[image_input, confidence_slider, show_preprocess],
        outputs=[result_text, original_output, preprocessed_output, confidence_plot]
    )

if __name__ == "__main__":
    demo.launch()