Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import tensorflow as tf | |
| from tensorflow.keras.applications import EfficientNetB0, ResNet50 | |
| from tensorflow.keras.applications.efficientnet import preprocess_input as efficient_preprocess | |
| from tensorflow.keras.applications.resnet import preprocess_input as resnet_preprocess | |
| from tensorflow.keras.preprocessing import image | |
| from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D | |
| from tensorflow.keras.models import Model | |
| import numpy as np | |
# Configuration
IMG_SIZE = 224  # inputs are resized to IMG_SIZE x IMG_SIZE pixels before inference
MODEL_TYPE = "efficientnet"  # backbone selector: "efficientnet" or "resnet"
# Label for each softmax output, in output order. Replace with your actual classes.
CLASS_NAMES = ["Cat", "Dog", "Panda"]
# Input tensor shared by the backbone and the classifier head.
input_tensor = Input(shape=(IMG_SIZE, IMG_SIZE, 3))

# Load the ImageNet-pretrained backbone without its top classification layer.
# An explicit lookup (rather than an if/else that falls back to ResNet for any
# value other than "efficientnet") makes a mistyped MODEL_TYPE fail at startup.
_BACKBONES = {
    "efficientnet": EfficientNetB0,
    "resnet": ResNet50,
}
if MODEL_TYPE not in _BACKBONES:
    raise ValueError(
        f"Unsupported MODEL_TYPE {MODEL_TYPE!r}; expected one of {sorted(_BACKBONES)}"
    )
base_model = _BACKBONES[MODEL_TYPE](
    weights='imagenet', include_top=False, input_tensor=input_tensor
)
base_model.trainable = False  # freeze the backbone; only the new head below is trainable

# Build the classifier head: pooled backbone features -> hidden layer -> one softmax unit per class.
x = GlobalAveragePooling2D()(base_model.output)
x = Dense(128, activation='relu')(x)
output = Dense(len(CLASS_NAMES), activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=output)
# Compile & load weights (optional)
# compile() is not needed for inference; it keeps the model ready for training/fine-tuning.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Trained weights for the full model (backbone + classifier head).
WEIGHTS_PATH = "your_weights.h5"

# Loading is best-effort so the demo still starts without weights. Without
# them, the Dense head is randomly initialised and its outputs are meaningless,
# so report that case explicitly instead of surfacing a raw I/O error.
if not tf.io.gfile.exists(WEIGHTS_PATH):
    print(
        f"⚠️ Weights file not found: {WEIGHTS_PATH!r}. "
        "The classifier head is untrained; predictions will be arbitrary."
    )
else:
    try:
        model.load_weights(WEIGHTS_PATH)
    except Exception as e:  # e.g. architecture/shape mismatch with the saved weights
        print("⚠️ Could not load weights:", e)
# Prediction function
def classify_image(img):
    """Classify a single image and return per-class probabilities.

    Args:
        img: A PIL image (what ``gr.Image(type="pil")`` supplies), or ``None``
            when the user submits without uploading anything.

    Returns:
        A dict mapping each name in ``CLASS_NAMES`` to its softmax probability
        as a Python float, which is the format ``gr.Label`` displays.

    Raises:
        gr.Error: if no image was provided, shown to the user in the UI.
    """
    if img is None:
        # Avoid an AttributeError on ``None.resize`` and tell the user instead.
        raise gr.Error("Please upload an image.")

    # Force 3 channels: RGBA, grayscale or palette images would otherwise yield
    # arrays that do not match the model's 3-channel input.
    img = img.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
    batch = np.expand_dims(image.img_to_array(img), axis=0)  # add a leading batch axis

    # Each backbone expects its own input normalisation.
    preprocess_fn = efficient_preprocess if MODEL_TYPE == "efficientnet" else resnet_preprocess
    batch = preprocess_fn(batch)

    # For a single small batch, calling the model directly avoids predict()'s per-call overhead.
    probs = np.asarray(model(batch, training=False))[0]
    return {name: float(p) for name, p in zip(CLASS_NAMES, probs)}
# Gradio UI: an image upload wired to classify_image, with the top-3 class probabilities shown.
demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    title="Transfer Learning Image Classifier",
    description="Upload an image to classify using EfficientNet or ResNet.",
)

# Start the server only when executed as a script, so importing this module
# (tests, notebooks, `gradio app.py` reload mode) does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()