"""Gradio demo: transfer-learning image classifier on a frozen ImageNet backbone.

A pretrained EfficientNetB0 or ResNet50 (selected by ``MODEL_TYPE``) is used as a
frozen feature extractor, topped with a small dense classification head. Weights
for the full model are loaded from ``WEIGHTS_PATH`` when that file is available.
"""

import gradio as gr
import tensorflow as tf
from tensorflow.keras.applications import EfficientNetB0, ResNet50
from tensorflow.keras.applications.efficientnet import preprocess_input as efficient_preprocess
from tensorflow.keras.applications.resnet import preprocess_input as resnet_preprocess
from tensorflow.keras.preprocessing import image
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
import numpy as np

# ✅ Config
IMG_SIZE = 224
MODEL_TYPE = "efficientnet"  # or "resnet"
CLASS_NAMES = ['Cat', 'Dog', 'Panda']  # Replace with your actual classes
WEIGHTS_PATH = "your_weights.h5"

# Each supported backbone paired with the preprocessing function that matches it,
# so the model and its input preprocessing can never be selected inconsistently.
BACKBONES = {
    "efficientnet": (EfficientNetB0, efficient_preprocess),
    "resnet": (ResNet50, resnet_preprocess),
}

try:
    backbone_cls, preprocess_fn = BACKBONES[MODEL_TYPE]
except KeyError:
    # Fail loudly instead of silently falling back to another backbone, which
    # would pair trained head weights with the wrong feature extractor.
    raise ValueError(
        f"Unknown MODEL_TYPE {MODEL_TYPE!r}; expected one of {sorted(BACKBONES)}"
    ) from None

# ✅ Input Tensor
input_tensor = Input(shape=(IMG_SIZE, IMG_SIZE, 3))

# ✅ Load Base Model (frozen: used only as a feature extractor)
base_model = backbone_cls(weights='imagenet', include_top=False, input_tensor=input_tensor)
base_model.trainable = False

# ✅ Build Classifier
x = GlobalAveragePooling2D()(base_model.output)
x = Dense(128, activation='relu')(x)
output = Dense(len(CLASS_NAMES), activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=output)

# ✅ Compile & Load Weights (Optional)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
try:
    model.load_weights(WEIGHTS_PATH)
except Exception as e:
    # Best-effort: the demo still starts, but the Dense head keeps its random
    # initialisation, so its predictions are arbitrary until weights are supplied.
    print(
        f"⚠️ Could not load weights from {WEIGHTS_PATH!r}: {e}. "
        "The classification head is untrained; predictions will be arbitrary."
    )


# ✅ Prediction Function
def classify_image(img):
    """Classify a PIL image and return a ``{class_name: probability}`` mapping.

    Args:
        img: PIL image from the Gradio ``Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        dict mapping each entry of ``CLASS_NAMES`` to its softmax probability.

    Raises:
        gr.Error: if no image was provided (Gradio shows the message in the UI).
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    # Grayscale, palette and RGBA uploads would otherwise yield 1 or 4 channels
    # and fail against the model's 3-channel input.
    img = img.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
    img_array = image.img_to_array(img)
    img_array = preprocess_fn(np.expand_dims(img_array, axis=0))
    # Direct call instead of model.predict(): cheaper for a single sample and
    # does not print a progress bar on every request.
    preds = np.asarray(model(img_array, training=False))[0]
    return {name: float(prob) for name, prob in zip(CLASS_NAMES, preds)}


# ✅ Gradio UI
demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    title="Transfer Learning Image Classifier",
    description="Upload an image to classify using EfficientNet or ResNet.",
)

if __name__ == "__main__":
    demo.launch()