Spaces:
Build error
Build error
| import gradio as gr | |
| import keras | |
| from keras.models import load_model | |
| from tensorflow_addons.layers import InstanceNormalization | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import tensorflow as tf | |
# Maps the custom layer name stored in the .h5 config to the class Keras
# should use when deserializing it.
cust = {'InstanceNormalization': InstanceNormalization}

# Pre-trained CycleGAN photo->Monet generator.
model = load_model(
    'g-cycleGAN-photo2monet-500images-epoch10_30_30_30_30_30_1000images_30_30_30.h5',
    custom_objects=cust,
)

# Example inputs offered in the UI: one inner list per example row, with one
# entry per input component (here a single image path).
path = [
    ['ex1.jpg'],
    ['ex2.jpg'],
    ['ex4.jpg'],
    ['ex6.jpg'],
    ['ex7.jpg'],
    ['ex8.jpg'],
    ['ex9.jpg'],
    ['ex10.jpg'],
    ['ex12.jpg'],
    ['ex13.jpg'],
]
# --- Preprocessing configuration -------------------------------------------
# Spatial size every input image is resized to before it reaches the generator.
IMG_HEIGHT = 256
IMG_WIDTH = 256

# tf.data pipeline settings. NOTE(review): none of these are referenced by the
# inference code in this script; presumably left over from training — confirm
# before removing.
BATCH_SIZE = 1
BUFFER_SIZE = 400
AUTOTUNE = tf.data.AUTOTUNE
def resize(image, height, width):
    """Resize *image* to exactly ``height`` x ``width`` pixels.

    Nearest-neighbour sampling is used, and the aspect ratio is not
    preserved: the image is stretched to the target size.

    Args:
        image: Image tensor to resize.
        height: Target height in pixels.
        width: Target width in pixels.

    Returns:
        The resized image tensor.
    """
    target_size = [height, width]
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
    return tf.image.resize(image, target_size, method=nearest)
def normalize(input_image):
    """Map pixel values from the [0, 255] range onto [-1, 1].

    Works element-wise on anything supporting ``/`` and ``-``: TensorFlow
    tensors, NumPy arrays, or plain numbers.
    """
    scaled = input_image / 127.5  # [0, 255] -> [0, 2]
    return scaled - 1             # [0, 2]   -> [-1, 1]
def load(img_file):
    """Read an image file from disk and return it as a float32 RGB tensor.

    Args:
        img_file: Path to an image in any format TensorFlow can decode
            (JPEG, PNG, GIF, BMP). For animated GIFs only the first frame
            is used.

    Returns:
        A float32 tensor of shape (height, width, 3) with pixel values in
        [0, 255].
    """
    raw = tf.io.read_file(img_file)
    # channels=3 forces RGB: grayscale images are expanded and an alpha
    # channel is dropped, so the result always has 3 channels (presumably what
    # the generator expects). Without it, grayscale or RGBA uploads produced
    # 1- or 4-channel tensors.
    # expand_animations=False guarantees a rank-3 result for every format.
    image = tf.io.decode_image(raw, channels=3, expand_animations=False)
    return tf.cast(image, tf.float32)
def load_image_test(image_file):
    """Load *image_file* and prepare it as generator input.

    Pipeline: read and decode -> resize to IMG_HEIGHT x IMG_WIDTH ->
    rescale pixel values to [-1, 1].
    """
    return normalize(resize(load(image_file), IMG_HEIGHT, IMG_WIDTH))
def show_preds_image(image_path):
    """Translate one photo into a Monet-style image with the generator.

    Args:
        image_path: Filesystem path to the input image, as provided by the
            ``Image(type="filepath")`` input. May be ``None`` if the user
            submits without an image.

    Returns:
        A NumPy array holding the generator output for the single image
        (batch dimension removed), with values clipped to [0, 1]. Returns
        ``None`` when no image was provided.
    """
    if image_path is None:
        # Nothing uploaded: leave the output empty instead of letting
        # tf.io.read_file fail on a None path.
        return None

    image = load_image_test(image_path)
    batch = tf.expand_dims(image, axis=0)  # the model takes a batch; add one
    prediction = model(batch, training=False)[0]

    # Inputs were scaled to [-1, 1] by normalize(); map the output back from
    # that range to [0, 1]. Clip so a slight overshoot can never yield an
    # out-of-range float image.
    prediction = tf.clip_by_value(prediction * 0.5 + 0.5, 0.0, 1.0)
    return prediction.numpy()
# --- Gradio UI ---------------------------------------------------------------
# NOTE(review): `shape=` and `.style()` are Gradio 3.x APIs that were removed in
# Gradio 4; confirm the pinned gradio version before upgrading.
input_image = gr.components.Image(
    shape=(256, 256),
    type="filepath",
    label="Input Image",
)
output_image = gr.components.Image(
    shape=(256, 256),
    type="numpy",
    label="Output Image",
).style(width=256, height=256)

inputs_image = [input_image]
outputs_image = [output_image]

# Single interface: uploaded photo path -> generator output array.
interface_image = gr.Interface(
    fn=show_preds_image,
    inputs=inputs_image,
    outputs=outputs_image,
    title="photo2monet",
    examples=path,
    cache_examples=False,
)

demo = gr.TabbedInterface(
    [interface_image],
    tab_names=['Image inference'],
)
demo.queue().launch()