Chobola committed
Commit · 27ec26c · 1 Parent(s): f76940a

colie import

app.py CHANGED
@@ -1,7 +1,140 @@
import os
import tempfile
import gradio as gr
from PIL import Image
import torch
import numpy as np

# Assuming the other files (utils.py, loss.py, siren.py, color.py, filter.py) are in the same directory
# and have been modified to remove all .cuda() calls and any GPU-specific code.
# For example, in utils.py: remove .cuda() from tensors, models, etc.
# Similarly for the others. Ensure everything runs on CPU, with torch.device('cpu') if needed.

from utils import get_image, get_v_component, replace_v_component, interpolate_image, get_coords, get_patches, filter_up
from loss import L_exp, L_TV
from siren import INF
from color import rgb2hsv_torch, hsv2rgb_torch

# Note: get_image is modified to take a PIL Image instead of a path.
# Add this function or modify utils.py accordingly.
def get_image_from_pil(pil_image):
    image = torch.from_numpy(np.array(pil_image)).float()
    image = image / torch.max(image)
    image = torch.movedim(image, -1, 0).unsqueeze(0)  # No .cuda()
    return image

# The enhancement function
def enhance_image(input_image, down_size, epochs, window, L, alpha, beta, gamma, delta):
    if input_image is None:
        raise gr.Error("Please upload an image.")

    # Process the image
    img_rgb = get_image_from_pil(input_image)
    img_hsv = rgb2hsv_torch(img_rgb)

    img_v = get_v_component(img_hsv)
    img_v_lr = interpolate_image(img_v, down_size, down_size)
    coords = get_coords(down_size, down_size)
    patches = get_patches(img_v_lr, window)

    img_siren = INF(patch_dim=window**2, num_layers=4, hidden_dim=256, add_layer=2)
    # No .cuda()

    optimizer = torch.optim.Adam(img_siren.parameters(), lr=1e-5, betas=(0.9, 0.999), weight_decay=3e-4)

    l_exp = L_exp(16, L)
    l_TV = L_TV()

    for epoch in range(epochs):
        img_siren.train()
        optimizer.zero_grad()

        illu_res_lr = img_siren(patches, coords)
        illu_res_lr = illu_res_lr.view(1, 1, down_size, down_size)
        illu_lr = illu_res_lr + img_v_lr

        img_v_fixed_lr = img_v_lr / (illu_lr + 1e-4)

        loss_spa = torch.mean(torch.abs(torch.pow(illu_lr - img_v_lr, 2)))
        loss_tv = l_TV(illu_lr)
        loss_exp = torch.mean(l_exp(illu_lr))
        loss_sparsity = torch.mean(img_v_fixed_lr)

        loss = loss_spa * alpha + loss_tv * beta + loss_exp * gamma + loss_sparsity * delta
        loss.backward()
        optimizer.step()

    img_v_fixed = filter_up(img_v_lr, img_v_fixed_lr, img_v)
    img_hsv_fixed = replace_v_component(img_hsv, img_v_fixed)
    img_rgb_fixed = hsv2rgb_torch(img_hsv_fixed)
    img_rgb_fixed = img_rgb_fixed / torch.max(img_rgb_fixed)

    enhanced_np = (torch.movedim(img_rgb_fixed, 1, -1)[0].detach().cpu().numpy() * 255).astype(np.uint8)
    enhanced_pil = Image.fromarray(enhanced_np)

    # Save to a temp file for download
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
        enhanced_pil.save(tmp_file.name)
        download_path = tmp_file.name

    return enhanced_pil, download_path

# Description from the README
description = """
# CoLIE: Fast Context-Based Low-Light Image Enhancement via Neural Implicit Representations

**Authors:** Tomáš Chobola*, Yu Liu, Hanyi Zhang, Julia A. Schnabel, Tingying Peng*
*Corresponding authors*
**Affiliations:** Technical University of Munich, Helmholtz AI, King’s College London

Accepted to ECCV 2024.

## Overview
- **Challenges with Current Methods:** Existing deep learning methods for low-light image enhancement struggle with high-resolution images and often fail to meet practical visual-perception needs in diverse, unseen scenarios.
- **Introduction of CoLIE:** CoLIE (Context-Based Low-Light Image Enhancement) is a novel approach for enhancing low-light images. It maps the 2D coordinates of an underexposed image to its illumination component, conditioned on local context.
- **Methodology:** The method uses the HSV color space for image reconstruction. It employs an implicit neural function together with an embedded guided filter to further reduce computational overhead.
- **Innovations in Training:** CoLIE introduces a single-image-based training loss function that improves the model's adaptability across various scenes, enhancing its practical applicability.

Upload a low-light image on the left, adjust hyperparameters, and click 'Enhance' to see the result on the right.
"""

# Gradio interface
with gr.Blocks(title="CoLIE - Low-Light Image Enhancement") as demo:
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Low-Light Image")
            down_size = gr.Slider(minimum=64, maximum=512, step=32, value=256, label="Down Size")
            epochs = gr.Slider(minimum=10, maximum=500, step=1, value=100, label="Epochs")
            window = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Window Size")
            L = gr.Slider(minimum=0.1, maximum=1.0, step=0.01, value=0.5, label="L (Optimally-Intense Threshold)")
            alpha = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=1.0, label="Alpha (Fidelity Control)")
            beta = gr.Slider(minimum=1.0, maximum=100.0, step=1.0, value=20.0, label="Beta (Illumination Smoothness)")
            gamma = gr.Slider(minimum=1.0, maximum=50.0, step=1.0, value=8.0, label="Gamma (Exposure Control)")
            delta = gr.Slider(minimum=1.0, maximum=50.0, step=1.0, value=5.0, label="Delta (Sparsity Level)")
            enhance_btn = gr.Button("Enhance")

        with gr.Column():
            output_image = gr.Image(label="Enhanced Image")
            output_download = gr.File(label="Download Output Image")
            # Optional: a download for the input is unnecessary, since the user uploaded it; right-clicking the input image works too.

    # Examples section (grid with one example)
    # Assumes an example image exists in the repo, e.g. "examples/low_light_example.png".
    # For the demo this is a placeholder; replace it with an actual path.
    examples = gr.Examples(
        examples=[
            ["examples/low_light_example.png", 256, 100, 1, 0.5, 1.0, 20.0, 8.0, 5.0]
        ],
        inputs=[input_image, down_size, epochs, window, L, alpha, beta, gamma, delta],
        label="Examples (Click to load image and hyperparameters)"
    )

    enhance_btn.click(
        enhance_image,
        inputs=[input_image, down_size, epochs, window, L, alpha, beta, gamma, delta],
        outputs=[output_image, output_download]
    )

demo.launch()
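For quick verification outside the browser, the click handler can be called directly; a minimal sketch, assuming the definitions above are in scope and that the example image path from the Examples grid exists (running app.py itself blocks on demo.launch()):

    from PIL import Image

    img = Image.open("examples/low_light_example.png").convert("RGB")
    enhanced, path = enhance_image(img, down_size=256, epochs=100, window=1,
                                   L=0.5, alpha=1.0, beta=20.0, gamma=8.0, delta=5.0)
    enhanced.save("enhanced.png")
    print("temp download file:", path)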
colie.py ADDED
@@ -0,0 +1,84 @@
from utils import *
from loss import *
from siren import INF
from color import rgb2hsv_torch, hsv2rgb_torch

import os
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm


parser = argparse.ArgumentParser(description='CoLIE')
parser.add_argument('--input_folder', type=str, default='input/')
parser.add_argument('--output_folder', type=str, default='output/')
parser.add_argument('--down_size', type=int, default=256, help='downsampling size')
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--window', type=int, default=1, help='context window size')
parser.add_argument('--L', type=float, default=0.5)
# loss function weight parameters
parser.add_argument('--alpha', type=float, required=True)
parser.add_argument('--beta', type=float, required=True)
parser.add_argument('--gamma', type=float, required=True)
parser.add_argument('--delta', type=float, required=True)
opt = parser.parse_args()


if not os.path.exists(opt.input_folder):
    print('input folder: {} does not exist'.format(opt.input_folder))
    exit()

if not os.path.exists(opt.output_folder):
    os.makedirs(opt.output_folder)


print(' > running')
for PATH in tqdm(np.sort(os.listdir(opt.input_folder))):
    img_rgb = get_image(os.path.join(opt.input_folder, PATH))
    img_hsv = rgb2hsv_torch(img_rgb)

    img_v = get_v_component(img_hsv)
    img_v_lr = interpolate_image(img_v, opt.down_size, opt.down_size)
    coords = get_coords(opt.down_size, opt.down_size)
    patches = get_patches(img_v_lr, opt.window)

    img_siren = INF(patch_dim=opt.window**2, num_layers=4, hidden_dim=256, add_layer=2)

    optimizer = torch.optim.Adam(img_siren.parameters(), lr=1e-5, betas=(0.9, 0.999), weight_decay=3e-4)

    l_exp = L_exp(16, opt.L)
    l_TV = L_TV()

    for epoch in range(opt.epochs):
        img_siren.train()
        optimizer.zero_grad()

        illu_res_lr = img_siren(patches, coords)
        illu_res_lr = illu_res_lr.view(1, 1, opt.down_size, opt.down_size)
        illu_lr = illu_res_lr + img_v_lr

        img_v_fixed_lr = img_v_lr / (illu_lr + 1e-4)

        loss_spa = torch.mean(torch.abs(torch.pow(illu_lr - img_v_lr, 2)))
        loss_tv = l_TV(illu_lr)
        loss_exp = torch.mean(l_exp(illu_lr))
        loss_sparsity = torch.mean(img_v_fixed_lr)

        loss = loss_spa * opt.alpha + loss_tv * opt.beta + loss_exp * opt.gamma + loss_sparsity * opt.delta
        loss.backward()
        optimizer.step()

    img_v_fixed = filter_up(img_v_lr, img_v_fixed_lr, img_v)
    img_hsv_fixed = replace_v_component(img_hsv, img_v_fixed)
    img_rgb_fixed = hsv2rgb_torch(img_hsv_fixed)
    img_rgb_fixed = img_rgb_fixed / torch.max(img_rgb_fixed)

    Image.fromarray(
        (torch.movedim(img_rgb_fixed, 1, -1)[0].detach().cpu().numpy() * 255).astype(np.uint8)
    ).save(os.path.join(opt.output_folder, PATH))

print(' > reconstruction done')
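For reference, a typical invocation; the four loss weights are required, and the values here simply mirror the app.py slider defaults rather than prescribed settings:

    python colie.py --input_folder input/ --output_folder output/ --alpha 1.0 --beta 20.0 --gamma 8.0 --delta 5.0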
color.py ADDED
@@ -0,0 +1,35 @@
import torch

def rgb2hsv_torch(rgb: torch.Tensor) -> torch.Tensor:
    cmax, cmax_idx = torch.max(rgb, dim=1, keepdim=True)
    cmin = torch.min(rgb, dim=1, keepdim=True)[0]
    delta = cmax - cmin
    hsv_h = torch.empty_like(rgb[:, 0:1, :, :])
    cmax_idx[delta == 0] = 3
    hsv_h[cmax_idx == 0] = (((rgb[:, 1:2] - rgb[:, 2:3]) / delta) % 6)[cmax_idx == 0]
    hsv_h[cmax_idx == 1] = (((rgb[:, 2:3] - rgb[:, 0:1]) / delta) + 2)[cmax_idx == 1]
    hsv_h[cmax_idx == 2] = (((rgb[:, 0:1] - rgb[:, 1:2]) / delta) + 4)[cmax_idx == 2]
    hsv_h[cmax_idx == 3] = 0.
    hsv_h /= 6.
    hsv_s = torch.where(cmax == 0, torch.tensor(0.).type_as(rgb), delta / cmax)
    hsv_v = cmax
    return torch.cat([hsv_h, hsv_s, hsv_v], dim=1)


def hsv2rgb_torch(hsv: torch.Tensor) -> torch.Tensor:
    hsv_h, hsv_s, hsv_l = hsv[:, 0:1], hsv[:, 1:2], hsv[:, 2:3]
    _c = hsv_l * hsv_s
    _x = _c * (- torch.abs(hsv_h * 6. % 2. - 1) + 1.)
    _m = hsv_l - _c
    _o = torch.zeros_like(_c)
    idx = (hsv_h * 6.).type(torch.uint8)
    idx = (idx % 6).expand(-1, 3, -1, -1)
    rgb = torch.empty_like(hsv)
    rgb[idx == 0] = torch.cat([_c, _x, _o], dim=1)[idx == 0]
    rgb[idx == 1] = torch.cat([_x, _c, _o], dim=1)[idx == 1]
    rgb[idx == 2] = torch.cat([_o, _c, _x], dim=1)[idx == 2]
    rgb[idx == 3] = torch.cat([_o, _x, _c], dim=1)[idx == 3]
    rgb[idx == 4] = torch.cat([_x, _o, _c], dim=1)[idx == 4]
    rgb[idx == 5] = torch.cat([_c, _o, _x], dim=1)[idx == 5]
    rgb += _m
    return rgb
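A quick round-trip sanity check for the two conversions, as a sketch; random inputs in [0, 1) should be reconstructed up to floating-point error:

    import torch
    from color import rgb2hsv_torch, hsv2rgb_torch

    rgb = torch.rand(1, 3, 8, 8)                 # random (1,3,H,W) image in [0, 1)
    back = hsv2rgb_torch(rgb2hsv_torch(rgb))     # HSV round trip
    print(torch.max(torch.abs(back - rgb)))      # expected to be near zero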
filter.py ADDED
@@ -0,0 +1,80 @@
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable

def diff_x(input, r):
    assert input.dim() == 4

    left = input[:, :, r:2 * r + 1]
    middle = input[:, :, 2 * r + 1:] - input[:, :, :-2 * r - 1]
    right = input[:, :, -1:] - input[:, :, -2 * r - 1: -r - 1]

    output = torch.cat([left, middle, right], dim=2)

    return output

def diff_y(input, r):
    assert input.dim() == 4

    left = input[:, :, :, r:2 * r + 1]
    middle = input[:, :, :, 2 * r + 1:] - input[:, :, :, :-2 * r - 1]
    right = input[:, :, :, -1:] - input[:, :, :, -2 * r - 1: -r - 1]

    output = torch.cat([left, middle, right], dim=3)

    return output

class BoxFilter(nn.Module):
    def __init__(self, r):
        super(BoxFilter, self).__init__()

        self.r = r

    def forward(self, x):
        assert x.dim() == 4

        return diff_y(diff_x(x.cumsum(dim=2), self.r).cumsum(dim=3), self.r)


class FastGuidedFilter(nn.Module):
    def __init__(self, r, eps=1e-8):
        super(FastGuidedFilter, self).__init__()

        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)

    def forward(self, lr_x, lr_y, hr_x):
        n_lrx, c_lrx, h_lrx, w_lrx = lr_x.size()
        n_lry, c_lry, h_lry, w_lry = lr_y.size()
        n_hrx, c_hrx, h_hrx, w_hrx = hr_x.size()

        assert n_lrx == n_lry and n_lry == n_hrx
        assert c_lrx == c_hrx and (c_lrx == 1 or c_lrx == c_lry)
        assert h_lrx == h_lry and w_lrx == w_lry
        assert h_lrx > 2*self.r+1 and w_lrx > 2*self.r+1

        ## N
        N = self.boxfilter(Variable(lr_x.data.new().resize_((1, 1, h_lrx, w_lrx)).fill_(1.0)))

        ## mean_x
        mean_x = self.boxfilter(lr_x) / N
        ## mean_y
        mean_y = self.boxfilter(lr_y) / N
        ## cov_xy
        cov_xy = self.boxfilter(lr_x * lr_y) / N - mean_x * mean_y
        ## var_x
        var_x = self.boxfilter(lr_x * lr_x) / N - mean_x * mean_x

        ## A
        A = cov_xy / (var_x + self.eps)
        ## b
        b = mean_y - A * mean_x

        ## mean_A; mean_b
        mean_A = F.interpolate(A, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
        mean_b = F.interpolate(b, (h_hrx, w_hrx), mode='bilinear', align_corners=True)

        return mean_A*hr_x+mean_b
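A shape-level usage sketch for FastGuidedFilter; the sizes are illustrative, and the low-res guide must satisfy H, W > 2r + 1 per the asserts above:

    import torch
    from filter import FastGuidedFilter

    gf = FastGuidedFilter(r=1)
    lr_x = torch.rand(1, 1, 64, 64)    # low-res guide (e.g. the V channel)
    lr_y = torch.rand(1, 1, 64, 64)    # low-res signal to upsample
    hr_x = torch.rand(1, 1, 256, 256)  # high-res guide
    print(gf(lr_x, lr_y, hr_x).shape)  # torch.Size([1, 1, 256, 256])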
loss.py ADDED
@@ -0,0 +1,31 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class L_exp(nn.Module):
    def __init__(self, patch_size, mean_val):
        super(L_exp, self).__init__()
        self.pool = nn.AvgPool2d(patch_size)
        self.mean_val = mean_val

    def forward(self, x):
        mean = self.pool(x) ** 0.5
        d = torch.abs(torch.mean(torch.pow(mean - torch.FloatTensor([self.mean_val]), 2)))
        return d


class L_TV(nn.Module):
    def __init__(self):
        super(L_TV, self).__init__()

    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = (x.size()[2] - 1) * x.size()[3]
        count_w = x.size()[2] * (x.size()[3] - 1)
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        return 2 * (h_tv / count_h + w_tv / count_w) / batch_size
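Both terms expect an (N, C, H, W) tensor such as the low-res illumination map; an illustrative evaluation, as a sketch:

    import torch
    from loss import L_exp, L_TV

    x = torch.rand(1, 1, 256, 256)
    print(L_exp(16, 0.5)(x))  # exposure term: sqrt of pooled brightness vs. the 0.5 target
    print(L_TV()(x))          # total-variation smoothness term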
siren.py ADDED
@@ -0,0 +1,61 @@
import torch
import torch.nn as nn
import numpy as np


class SirenLayer(nn.Module):
    def __init__(self, in_f, out_f, w0=30, is_first=False, is_last=False):
        super().__init__()
        self.in_f = in_f
        self.w0 = w0
        self.linear = nn.Linear(in_f, out_f)
        self.is_first = is_first
        self.is_last = is_last
        if not self.is_last:
            self.init_weights()

    def init_weights(self):
        b = 1 / self.in_f if self.is_first else np.sqrt(6 / self.in_f) / self.w0
        with torch.no_grad():
            self.linear.weight.uniform_(-b, b)

    def forward(self, x):
        x = self.linear(x)
        return nn.Sigmoid()(x) if self.is_last else torch.sin(self.w0 * x)


class INF(nn.Module):
    def __init__(self, patch_dim, num_layers, hidden_dim, add_layer, weight_decay=None):
        super().__init__()
        '''
        `add_layer` should be in range of [1, num_layers-2]
        '''

        patch_layers = [SirenLayer(patch_dim, hidden_dim, is_first=True)]
        spatial_layers = [SirenLayer(2, hidden_dim, is_first=True)]
        output_layers = []

        for _ in range(1, add_layer - 2):
            patch_layers.append(SirenLayer(hidden_dim, hidden_dim))
            spatial_layers.append(SirenLayer(hidden_dim, hidden_dim))
        patch_layers.append(SirenLayer(hidden_dim, hidden_dim//2))
        spatial_layers.append(SirenLayer(hidden_dim, hidden_dim//2))

        for _ in range(add_layer, num_layers - 1):
            output_layers.append(SirenLayer(hidden_dim, hidden_dim))
        output_layers.append(SirenLayer(hidden_dim, 1, is_last=True))

        self.patch_net = nn.Sequential(*patch_layers)
        self.spatial_net = nn.Sequential(*spatial_layers)
        self.output_net = nn.Sequential(*output_layers)

        if not weight_decay:
            weight_decay = [0.1, 0.0001, 0.001]

        self.params = []
        self.params += [{'params': self.spatial_net.parameters(), 'weight_decay': weight_decay[0]}]
        self.params += [{'params': self.patch_net.parameters(), 'weight_decay': weight_decay[1]}]
        self.params += [{'params': self.output_net.parameters(), 'weight_decay': weight_decay[2]}]

    def forward(self, patch, spatial):
        return self.output_net(torch.cat((self.patch_net(patch), self.spatial_net(spatial)), -1))
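A shape sketch for INF as instantiated elsewhere in this commit (window=1 gives patch_dim=1; both inputs carry features in their last dimension):

    import torch
    from siren import INF

    net = INF(patch_dim=1, num_layers=4, hidden_dim=256, add_layer=2)
    patches = torch.rand(256, 256, 1)  # get_patches output for a 256x256 V channel, window=1
    coords = torch.rand(256, 256, 2)   # stand-in for get_coords(256, 256), values in [0, 1]
    print(net(patches, coords).shape)  # torch.Size([256, 256, 1]); later viewed as (1, 1, 256, 256)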
utils.py ADDED
@@ -0,0 +1,76 @@
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from filter import FastGuidedFilter


def get_image(path):
    """
    Reads and returns RGB image, (1,3,H,W).
    """
    image = torch.from_numpy(np.array(Image.open(path))).float()
    image = image / torch.max(image)
    image = torch.movedim(image, -1, 0).unsqueeze(0)
    return image


def get_v_component(img_hsv):
    """
    Assumes (1,3,H,W) HSV image.
    """
    return img_hsv[:, -1].unsqueeze(0)


def replace_v_component(img_hsv, v_new):
    """
    Replaces the V component of an HSV image (1,3,H,W).
    """
    img_hsv[:, -1] = v_new
    return img_hsv


def interpolate_image(img, H, W):
    """
    Reshapes the image based on new resolution.
    """
    return F.interpolate(img, size=(H, W))


def get_coords(H, W):
    """
    Creates a coordinates grid for INF.
    """
    coords = np.dstack(np.meshgrid(np.linspace(0, 1, H), np.linspace(0, 1, W)))
    coords = torch.from_numpy(coords).float()
    return coords


def get_patches(img, KERNEL_SIZE):
    """
    Creates a tensor where the channel contains patch information.
    """
    kernel = torch.zeros((KERNEL_SIZE ** 2, 1, KERNEL_SIZE, KERNEL_SIZE))

    for i in range(KERNEL_SIZE):
        for j in range(KERNEL_SIZE):
            kernel[int(torch.sum(kernel).item()), 0, i, j] = 1

    pad = nn.ReflectionPad2d(KERNEL_SIZE // 2)
    im_padded = pad(img)

    extracted = torch.nn.functional.conv2d(im_padded, kernel, padding=0).squeeze(0)

    return torch.movedim(extracted, 0, -1)


def filter_up(x_lr, y_lr, x_hr, r=1):
    """
    Applies the guided filter to upscale the predicted image.
    """
    guided_filter = FastGuidedFilter(r=r)
    y_hr = guided_filter(x_lr, y_lr, x_hr)
    y_hr = torch.clip(y_hr, 0, 1)
    return y_hr
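Illustrative input/output shapes for the two main helpers, as a sketch:

    import torch
    from utils import get_coords, get_patches

    v = torch.rand(1, 1, 64, 64)     # low-res V channel
    print(get_coords(64, 64).shape)  # torch.Size([64, 64, 2])
    print(get_patches(v, 3).shape)   # torch.Size([64, 64, 9]): one flattened 3x3 patch per pixel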