# Hugging Face Space status at time of capture: Runtime error
# File size: 2,166 bytes
import gradio as gr
import tensorflow as tf
from tensorflow.keras.applications import EfficientNetB0, ResNet50
from tensorflow.keras.applications.efficientnet import preprocess_input as efficient_preprocess
from tensorflow.keras.applications.resnet import preprocess_input as resnet_preprocess
from tensorflow.keras.preprocessing import image
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
import numpy as np
# Config
# (The original "✅" emoji here was mis-decoded and split the comment, leaving a
# bare `Config` statement that raised NameError at import time.)
IMG_SIZE = 224  # Square side length (pixels) used for the Input layer and for resizing uploads
MODEL_TYPE = "efficientnet"  # or "resnet"
CLASS_NAMES = ['Cat', 'Dog', 'Panda']  # Replace with your actual classes
# Input tensor: one IMG_SIZE x IMG_SIZE image with 3 channels (the batch dimension is implicit).
input_tensor = Input(shape=(IMG_SIZE, IMG_SIZE, 3))
# Load base model: ImageNet-pretrained backbone without its classification top,
# wired to `input_tensor`.
if MODEL_TYPE == "efficientnet":
    base_model = EfficientNetB0(weights='imagenet', include_top=False, input_tensor=input_tensor)
elif MODEL_TYPE == "resnet":
    base_model = ResNet50(weights='imagenet', include_top=False, input_tensor=input_tensor)
else:
    # Fail fast: previously any unrecognised value (e.g. a typo) silently fell back to ResNet50.
    raise ValueError(
        f"Unsupported MODEL_TYPE {MODEL_TYPE!r}; expected 'efficientnet' or 'resnet'"
    )
# Freeze the backbone so only the new classifier head is trainable.
base_model.trainable = False
# Build classifier: pooled backbone features -> 128-unit ReLU layer -> softmax over CLASS_NAMES.
x = GlobalAveragePooling2D()(base_model.output)  # collapse spatial dims into one feature vector
x = Dense(128, activation='relu')(x)
output = Dense(len(CLASS_NAMES), activation='softmax')(x)  # one probability per class name
model = Model(inputs=base_model.input, outputs=output)
# Compile & load weights (optional).
# Compiling is only needed if the model is trained further; predict() does not require it.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
try:
    model.load_weights("your_weights.h5")
except Exception as e:
    # Deliberately best-effort so the app still starts. Without trained weights, the two
    # Dense layers keep their fresh initialisation, so predictions are not meaningful.
    print("WARNING: Could not load weights:", e)
# Prediction function
def classify_image(img):
    """Classify an uploaded image and return a ``{class_name: probability}`` mapping.

    Args:
        img: PIL image delivered by the Gradio ``Image`` component, or ``None``
            when the user submits without uploading anything.

    Returns:
        Dict mapping each entry of ``CLASS_NAMES`` to the model's softmax score
        as a Python float, or an empty dict when no image was provided.
    """
    if img is None:
        return {}
    # Force 3 channels: RGBA, grayscale ("L") or palette ("P") uploads would otherwise
    # produce an array whose channel count does not match the model's 3-channel input.
    img = img.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
    # Add the batch dimension -> (1, IMG_SIZE, IMG_SIZE, 3).
    batch = np.expand_dims(image.img_to_array(img), axis=0)
    # Each backbone expects its own input normalisation.
    preprocess_fn = efficient_preprocess if MODEL_TYPE == "efficientnet" else resnet_preprocess
    batch = preprocess_fn(batch)
    # verbose=0 suppresses the per-call progress bar in the server logs.
    probs = model.predict(batch, verbose=0)[0]
    return {name: float(p) for name, p in zip(CLASS_NAMES, probs)}
# Gradio UI: one image input, one label output listing the top-scoring classes.
demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),  # deliver uploads as PIL images, as classify_image expects
    outputs=gr.Label(num_top_classes=3),
    title="Transfer Learning Image Classifier",
    description="Upload an image to classify using EfficientNet or ResNet."
)
demo.launch()
|