Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import toml | |
| import torch | |
| from PIL import Image | |
| from torch import nn | |
| from torchvision import transforms | |
| import net | |
| from function import * | |
cfg = toml.load("config.toml")  # static variables

# Setup device: use CUDA only when it is both available and enabled in config.
if torch.cuda.is_available() and cfg["use_cuda"]:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load pretrained models
decoder = net.decoder
vgg = net.vgg
decoder.eval()
vgg.eval()
# map_location remaps stored tensors onto the selected device. Without it, a
# checkpoint containing CUDA tensors fails to load on a CPU-only host.
decoder.load_state_dict(torch.load(cfg["decoder_weight"], map_location=device))
vgg.load_state_dict(torch.load(cfg["vgg_weight"], map_location=device))
# Keep only the first 31 child modules of the encoder as the feature
# extractor (presumably up to relu4_1, as in the AdaIN reference code --
# confirm against net.vgg).
vgg = nn.Sequential(*list(vgg.children())[:31])
vgg = vgg.to(device)
decoder = decoder.to(device)
def transform(img, size, crop):
    """Convert a PIL image to a tensor, optionally resizing and cropping it.

    Args:
        img: PIL image to convert.
        size: Size passed to ``transforms.Resize``. A value <= 0 disables
            resizing and cropping, so the image keeps its own resolution.
        crop: When true and ``size > 0``, apply ``transforms.CenterCrop(size)``
            after the resize.

    Returns:
        The tensor produced by ``transforms.ToTensor``.
    """
    steps = []
    if size > 0:
        steps.append(transforms.Resize(size))
        # CenterCrop needs a positive size, so cropping is only meaningful
        # together with a resize.
        if crop:
            steps.append(transforms.CenterCrop(size))
    steps.append(transforms.ToTensor())
    # Use a distinct local name so the pipeline does not shadow this function.
    pipeline = transforms.Compose(steps)
    return pipeline(img)
def style_transfer(content, style, style_type, alpha, keep_resolution):
    """Stylize ``content`` with the feature statistics of ``style``.

    Args:
        content: PIL content image.
        style: PIL style image.
        style_type: Method name, matched case-insensitively against the keys
            "efdm", "adain", "adamean", "adastd" and "hm". ``None`` or an
            empty value falls back to "efdm".
        alpha: Content-style trade-off. The transferred features are blended
            as ``feat * alpha + content_feat * (1 - alpha)``.
        keep_resolution: When true, the style image is resized to the
            content image's size and both are processed without further
            resizing. When false and the method is EFDM, the sizes and crop
            from ``cfg`` are applied.

    Returns:
        The stylized image as a PIL image.

    Raises:
        KeyError: If ``style_type`` names an unsupported method.
    """
    # An unselected Gradio Radio yields None; fall back to EFDM instead of
    # crashing on None.lower().
    style_type = (style_type or "EFDM").lower()

    # Resolve the method before any heavy work so an unknown name fails fast.
    transfer = {
        "adain": adaptive_instance_normalization,
        "adamean": adaptive_mean_normalization,
        "adastd": adaptive_std_normalization,
        "efdm": exact_feature_distribution_matching,
        "hm": histogram_matching,
    }[style_type]

    # Step 1: convert image to PyTorch Tensor.
    # Uploads may be RGBA, palette or grayscale. Force 3 channels so the
    # tensors match the encoder's expected input.
    content = content.convert("RGB")
    style = style.convert("RGB")
    if keep_resolution:
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        style = style.resize(content.size, Image.LANCZOS)
    if style_type == "efdm" and not keep_resolution:
        content = transform(content, cfg["content_size"], cfg["crop"])
        style = transform(style, cfg["style_size"], cfg["crop"])
    else:
        content = transform(content, -1, False)
        style = transform(style, -1, False)
    content = content.to(device).unsqueeze(0)
    style = style.to(device).unsqueeze(0)

    # Steps 2-5 are pure inference. Disabling autograd avoids keeping a
    # gradient graph alive for the activations.
    with torch.no_grad():
        # Step 2: extract content feature and style feature
        content_feat = vgg(content)
        style_feat = vgg(style)
        # Step 3: perform style transfer
        feat = transfer(content_feat, style_feat)
        # Step 4: content-style trade-off
        feat = feat * alpha + content_feat * (1 - alpha)
        # Step 5: decode to image
        output = decoder(feat).cpu().squeeze(0).clamp_(0, 1)
    output = transforms.ToPILImage()(output)
    if torch.cuda.is_available():
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
    return output
# Add image examples: each content image path maps to the style image path
# it is paired with in the demo's example gallery (insertion order kept).
example_img_pairs = dict(
    [
        ("examples/content/sailboat.jpg", "examples/style/sketch.png"),
        ("examples/content/granatum.jpg", "examples/style/flowers_in_a_turquoise_vase.jpg"),
        ("examples/content/einstein.jpeg", "examples/style/polasticot2.jpeg"),
        ("examples/content/paris.jpeg", "examples/style/vangogh.jpeg"),
        ("examples/content/cornell.jpg", "examples/style/asheville.jpg"),
    ]
)
# Customize interface
title = "Style Transfer with EFDM"
description = """
Gradio demo for neural style transfer using exact feature distribution matching
"""
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2203.07740'>Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization</a></p>"
content_input = gr.inputs.Image(label="Content Image", source="upload", type="pil")
style_input = gr.inputs.Image(label="Style Image", source="upload", type="pil")
# Preselect EFDM (the method the examples use). Without a default the Radio
# sends None when the user submits without choosing a method.
style_type = gr.inputs.Radio(
    ["EFDM", "AdaIN", "AdaMean", "AdaStd", "HM"], label="Method", default="EFDM"
)
alpha_selector = gr.inputs.Slider(
    minimum=0.0, maximum=1.0, step=0.01, default=1.0, label="Content-Style trade-off"
)
keep_resolution = gr.inputs.Checkbox(
    default=True, label="Keep content image resolution"
)
iface = gr.Interface(
    fn=style_transfer,
    inputs=[content_input, style_input, style_type, alpha_selector, keep_resolution],
    outputs=["image"],
    title=title,
    description=description,
    article=article,
    theme="huggingface",
    # Every example runs EFDM at full stylization while keeping the content
    # resolution.
    examples=[
        [content, style, "EFDM", 1.0, True]
        for content, style in example_img_pairs.items()
    ],
)
iface.launch(debug=False, enable_queue=True)