# Author: mahesh1209
# Commit: 963df84 "Update app.py" (verified)
import gradio as gr
import tensorflow as tf
from tensorflow.keras.applications import EfficientNetB0, ResNet50
from tensorflow.keras.applications.efficientnet import preprocess_input as efficient_preprocess
from tensorflow.keras.applications.resnet import preprocess_input as resnet_preprocess
from tensorflow.keras.preprocessing import image
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
import numpy as np
# ✅ Config
IMG_SIZE = 224  # square input resolution (pixels) fed to the backbone
MODEL_TYPE = "efficientnet"  # backbone to use: "efficientnet" or "resnet"
# Replace with your actual classes, in the same order as the training labels:
# output unit i of the classifier is reported as CLASS_NAMES[i].
CLASS_NAMES = ['Cat', 'Dog', 'Panda']
# ✅ Input Tensor
input_tensor = Input(shape=(IMG_SIZE, IMG_SIZE, 3))

# ✅ Load Base Model
# Map each supported MODEL_TYPE to its ImageNet backbone constructor so an
# unknown value fails loudly instead of silently falling back to ResNet50.
_BACKBONES = {"efficientnet": EfficientNetB0, "resnet": ResNet50}
if MODEL_TYPE not in _BACKBONES:
    raise ValueError(
        f"Unknown MODEL_TYPE {MODEL_TYPE!r}; expected one of {sorted(_BACKBONES)}"
    )
base_model = _BACKBONES[MODEL_TYPE](
    weights='imagenet', include_top=False, input_tensor=input_tensor
)
# Freeze the pretrained backbone so only the new head's weights are trainable.
base_model.trainable = False

# ✅ Build Classifier
# Pool the backbone feature map, then a small dense head with one softmax
# unit per entry in CLASS_NAMES.
x = GlobalAveragePooling2D()(base_model.output)
x = Dense(128, activation='relu')(x)
output = Dense(len(CLASS_NAMES), activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=output)
# ✅ Compile & Load Weights (Optional)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Weights file is expected to match this exact architecture (backbone + head).
WEIGHTS_PATH = "your_weights.h5"
try:
    model.load_weights(WEIGHTS_PATH)
except Exception as e:
    # Best-effort: the app still starts, but the Dense head keeps its random
    # initialization, so predictions are not meaningful until weights load.
    print("⚠️ Could not load weights:", e)
# ✅ Prediction Function
def classify_image(img):
    """Classify a PIL image and return a ``{class_name: probability}`` mapping.

    The image is converted to RGB, resized to ``IMG_SIZE`` x ``IMG_SIZE``,
    batched, passed through the backbone-specific ``preprocess_input`` and
    then through ``model``. One entry is returned per name in ``CLASS_NAMES``.

    Raises:
        gr.Error: if no image was supplied (``img is None``).
    """
    if img is None:
        # e.g. the user pressed Submit without uploading an image; show a
        # readable message instead of an AttributeError traceback.
        raise gr.Error("Please upload an image.")
    # The network input has exactly 3 channels; RGBA, grayscale or palette
    # images would otherwise produce a wrongly shaped array.
    img = img.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
    img_array = np.expand_dims(image.img_to_array(img), axis=0)  # add batch axis
    preprocess_fn = efficient_preprocess if MODEL_TYPE == "efficientnet" else resnet_preprocess
    img_array = preprocess_fn(img_array)
    # verbose=0 keeps predict() from printing a progress bar on every request.
    preds = model.predict(img_array, verbose=0)[0]
    return {name: float(prob) for name, prob in zip(CLASS_NAMES, preds)}
# ✅ Gradio UI
demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    # Show every class; stays correct if CLASS_NAMES is edited.
    outputs=gr.Label(num_top_classes=len(CLASS_NAMES)),
    title="Transfer Learning Image Classifier",
    description="Upload an image to classify using EfficientNet or ResNet."
)

# Launch only when run as a script, so importing this module (e.g. for
# tests or Gradio's reload mode) does not start a server.
if __name__ == "__main__":
    demo.launch()