Chobola committed on
Commit 27ec26c · 1 Parent(s): f76940a

colie import

Files changed (7)
  1. app.py +137 -4
  2. colie.py +84 -0
  3. color.py +35 -0
  4. filter.py +80 -0
  5. loss.py +31 -0
  6. siren.py +61 -0
  7. utils.py +76 -0
app.py CHANGED
@@ -1,7 +1,140 @@
 import gradio as gr
 
-def greet(name):
-    return "Hello " + name + "!!"
 
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+import os
+import tempfile
 import gradio as gr
+from PIL import Image
+import torch
+import numpy as np
+
+# Assumes the other files (utils.py, loss.py, siren.py, color.py, filter.py) are in the same
+# directory and have been modified to remove all .cuda() calls and other GPU-specific code,
+# so that everything runs on the CPU (use torch.device('cpu') explicitly if needed).
+
+from utils import get_image, get_v_component, replace_v_component, interpolate_image, get_coords, get_patches, filter_up
+from loss import L_exp, L_TV
+from siren import INF
+from color import rgb2hsv_torch, hsv2rgb_torch
+
+# Note: utils.get_image takes a file path; this variant accepts a PIL Image instead.
+def get_image_from_pil(pil_image):
+    image = torch.from_numpy(np.array(pil_image.convert("RGB"))).float()  # force 3 channels
+    image = image / torch.max(image)
+    image = torch.movedim(image, -1, 0).unsqueeze(0)  # (1, 3, H, W), stays on the CPU
+    return image
+
+# The enhancement function
+def enhance_image(input_image, down_size, epochs, window, L, alpha, beta, gamma, delta):
+    if input_image is None:
+        raise gr.Error("Please upload an image.")
+
+    # Process the image
+    img_rgb = get_image_from_pil(input_image)
+    img_hsv = rgb2hsv_torch(img_rgb)
+
+    img_v = get_v_component(img_hsv)
+    img_v_lr = interpolate_image(img_v, down_size, down_size)
+    coords = get_coords(down_size, down_size)
+    patches = get_patches(img_v_lr, window)
+
+    img_siren = INF(patch_dim=window**2, num_layers=4, hidden_dim=256, add_layer=2)  # no .cuda()
+
+    optimizer = torch.optim.Adam(img_siren.parameters(), lr=1e-5, betas=(0.9, 0.999), weight_decay=3e-4)
+
+    l_exp = L_exp(16, L)
+    l_TV = L_TV()
+
+    for epoch in range(epochs):
+        img_siren.train()
+        optimizer.zero_grad()
+
+        illu_res_lr = img_siren(patches, coords)
+        illu_res_lr = illu_res_lr.view(1, 1, down_size, down_size)
+        illu_lr = illu_res_lr + img_v_lr
+
+        img_v_fixed_lr = img_v_lr / (illu_lr + 1e-4)
+
+        loss_spa = torch.mean(torch.abs(torch.pow(illu_lr - img_v_lr, 2)))
+        loss_tv = l_TV(illu_lr)
+        loss_exp = torch.mean(l_exp(illu_lr))
+        loss_sparsity = torch.mean(img_v_fixed_lr)
+
+        loss = loss_spa * alpha + loss_tv * beta + loss_exp * gamma + loss_sparsity * delta
+        loss.backward()
+        optimizer.step()
+
+    img_v_fixed = filter_up(img_v_lr, img_v_fixed_lr, img_v)
+    img_hsv_fixed = replace_v_component(img_hsv, img_v_fixed)
+    img_rgb_fixed = hsv2rgb_torch(img_hsv_fixed)
+    img_rgb_fixed = img_rgb_fixed / torch.max(img_rgb_fixed)
+
+    enhanced_np = (torch.movedim(img_rgb_fixed, 1, -1)[0].detach().cpu().numpy() * 255).astype(np.uint8)
+    enhanced_pil = Image.fromarray(enhanced_np)
+
+    # Save to a temp file so the result can be offered for download
+    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
+        enhanced_pil.save(tmp_file.name)
+        download_path = tmp_file.name
+
+    return enhanced_pil, download_path
+
+# Description from the README
+description = """
+# CoLIE: Fast Context-Based Low-Light Image Enhancement via Neural Implicit Representations
+
+**Authors:** Tomáš Chobola*, Yu Liu, Hanyi Zhang, Julia A. Schnabel, Tingying Peng*
+*Corresponding authors*
+**Affiliations:** Technical University of Munich, Helmholtz AI, King’s College London
+
+Accepted to ECCV 2024.
+
+## Overview
+- **Challenges with Current Methods:** Existing deep learning methods for low-light image enhancement struggle with high-resolution images and often fail to meet the practical visual-perception needs of diverse, unseen scenarios.
+- **Introduction of CoLIE:** CoLIE (Context-Based Low-Light Image Enhancement) is a novel approach that enhances low-light images by mapping the 2D coordinates of an underexposed image to its illumination component, conditioned on local context.
+- **Methodology:** The method reconstructs the image in HSV color space, employing an implicit neural function together with an embedded guided filter to reduce computational overhead.
+- **Innovations in Training:** CoLIE introduces a training loss computed from the single input image, which improves the model's adaptability across varied scenes and enhances its practical applicability.
+
+Upload a low-light image on the left, adjust the hyperparameters, and click 'Enhance' to see the result on the right.
+"""
+
+# Gradio interface
+with gr.Blocks(title="CoLIE - Low-Light Image Enhancement") as demo:
+    gr.Markdown(description)
+
+    with gr.Row():
+        with gr.Column():
+            input_image = gr.Image(type="pil", label="Upload Low-Light Image")
+            down_size = gr.Slider(minimum=64, maximum=512, step=32, value=256, label="Down Size")
+            epochs = gr.Slider(minimum=10, maximum=500, step=1, value=100, label="Epochs")
+            window = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Window Size")
+            L = gr.Slider(minimum=0.1, maximum=1.0, step=0.01, value=0.5, label="L (Optimally-Intense Threshold)")
+            alpha = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=1.0, label="Alpha (Fidelity Control)")
+            beta = gr.Slider(minimum=1.0, maximum=100.0, step=1.0, value=20.0, label="Beta (Illumination Smoothness)")
+            gamma = gr.Slider(minimum=1.0, maximum=50.0, step=1.0, value=8.0, label="Gamma (Exposure Control)")
+            delta = gr.Slider(minimum=1.0, maximum=50.0, step=1.0, value=5.0, label="Delta (Sparsity Level)")
+            enhance_btn = gr.Button("Enhance")
+
+        with gr.Column():
+            output_image = gr.Image(label="Enhanced Image")
+            output_download = gr.File(label="Download Output Image")
+            # No separate download widget for the input: the user uploaded it and can
+            # also right-click the input image to save it.
+
+    # Examples section (grid with one example). Assumes an example image exists in the
+    # repo, e.g. "examples/low_light_example.png"; replace with an actual path.
+    examples = gr.Examples(
+        examples=[
+            ["examples/low_light_example.png", 256, 100, 1, 0.5, 1.0, 20.0, 8.0, 5.0]
+        ],
+        inputs=[input_image, down_size, epochs, window, L, alpha, beta, gamma, delta],
+        label="Examples (click to load image and hyperparameters)"
+    )
+
+    enhance_btn.click(
+        enhance_image,
+        inputs=[input_image, down_size, epochs, window, L, alpha, beta, gamma, delta],
+        outputs=[output_image, output_download]
+    )
+
+demo.launch()
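One practical caveat for reuse: demo.launch() runs at import time, so importing app.py from another script will start the server; the conventional fix is an `if __name__ == "__main__":` guard around the launch. With such a guard in place, the enhancement entry point can be driven headlessly. A minimal sketch, assuming that guard and a placeholder input path:

    from PIL import Image
    from app import enhance_image  # safe only once demo.launch() is guarded

    # Arguments mirror the slider defaults above.
    enhanced, download_path = enhance_image(
        Image.open("low_light.png"),  # placeholder path
        256, 100, 1, 0.5, 1.0, 20.0, 8.0, 5.0,
    )
    enhanced.save("enhanced.png")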
colie.py ADDED
@@ -0,0 +1,84 @@
+from utils import *
+from loss import *
+from siren import INF
+from color import rgb2hsv_torch, hsv2rgb_torch
+
+import os
+import argparse
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+
+parser = argparse.ArgumentParser(description='CoLIE')
+parser.add_argument('--input_folder', type=str, default='input/')
+parser.add_argument('--output_folder', type=str, default='output/')
+parser.add_argument('--down_size', type=int, default=256, help='downsampling size')
+parser.add_argument('--epochs', type=int, default=100)
+parser.add_argument('--window', type=int, default=1, help='context window size')
+parser.add_argument('--L', type=float, default=0.5)
+# loss function weight parameters
+parser.add_argument('--alpha', type=float, required=True)
+parser.add_argument('--beta', type=float, required=True)
+parser.add_argument('--gamma', type=float, required=True)
+parser.add_argument('--delta', type=float, required=True)
+opt = parser.parse_args()
+
+
+if not os.path.exists(opt.input_folder):
+    print('input folder: {} does not exist'.format(opt.input_folder))
+    exit()
+
+if not os.path.exists(opt.output_folder):
+    os.makedirs(opt.output_folder)
+
+
+print(' > running')
+for PATH in tqdm(np.sort(os.listdir(opt.input_folder))):
+    img_rgb = get_image(os.path.join(opt.input_folder, PATH))
+    img_hsv = rgb2hsv_torch(img_rgb)
+
+    img_v = get_v_component(img_hsv)
+    img_v_lr = interpolate_image(img_v, opt.down_size, opt.down_size)
+    coords = get_coords(opt.down_size, opt.down_size)
+    patches = get_patches(img_v_lr, opt.window)
+
+    img_siren = INF(patch_dim=opt.window**2, num_layers=4, hidden_dim=256, add_layer=2)
+
+    optimizer = torch.optim.Adam(img_siren.parameters(), lr=1e-5, betas=(0.9, 0.999), weight_decay=3e-4)
+
+    l_exp = L_exp(16, opt.L)
+    l_TV = L_TV()
+
+    for epoch in range(opt.epochs):
+        img_siren.train()
+        optimizer.zero_grad()
+
+        illu_res_lr = img_siren(patches, coords)
+        illu_res_lr = illu_res_lr.view(1, 1, opt.down_size, opt.down_size)
+        illu_lr = illu_res_lr + img_v_lr
+
+        img_v_fixed_lr = img_v_lr / (illu_lr + 1e-4)
+
+        loss_spa = torch.mean(torch.abs(torch.pow(illu_lr - img_v_lr, 2)))
+        loss_tv = l_TV(illu_lr)
+        loss_exp = torch.mean(l_exp(illu_lr))
+        loss_sparsity = torch.mean(img_v_fixed_lr)
+
+        loss = loss_spa * opt.alpha + loss_tv * opt.beta + loss_exp * opt.gamma + loss_sparsity * opt.delta
+        loss.backward()
+        optimizer.step()
+
+    img_v_fixed = filter_up(img_v_lr, img_v_fixed_lr, img_v)
+    img_hsv_fixed = replace_v_component(img_hsv, img_v_fixed)
+    img_rgb_fixed = hsv2rgb_torch(img_hsv_fixed)
+    img_rgb_fixed = img_rgb_fixed / torch.max(img_rgb_fixed)
+
+    Image.fromarray(
+        (torch.movedim(img_rgb_fixed, 1, -1)[0].detach().cpu().numpy() * 255).astype(np.uint8)
+    ).save(os.path.join(opt.output_folder, PATH))
+
+print(' > reconstruction done')
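Because the four loss weights are required arguments, colie.py refuses to run without them. A minimal sketch of an invocation from Python, with placeholder folders and the same weights the app exposes as defaults:

    import subprocess

    subprocess.run(
        ["python", "colie.py",
         "--input_folder", "input/", "--output_folder", "output/",
         "--down_size", "256", "--epochs", "100", "--window", "1", "--L", "0.5",
         "--alpha", "1.0", "--beta", "20.0", "--gamma", "8.0", "--delta", "5.0"],
        check=True,
    )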
color.py ADDED
@@ -0,0 +1,35 @@
+import torch
+
+def rgb2hsv_torch(rgb: torch.Tensor) -> torch.Tensor:
+    cmax, cmax_idx = torch.max(rgb, dim=1, keepdim=True)
+    cmin = torch.min(rgb, dim=1, keepdim=True)[0]
+    delta = cmax - cmin
+    hsv_h = torch.empty_like(rgb[:, 0:1, :, :])
+    cmax_idx[delta == 0] = 3  # achromatic pixels: hue is set to 0 below
+    hsv_h[cmax_idx == 0] = (((rgb[:, 1:2] - rgb[:, 2:3]) / delta) % 6)[cmax_idx == 0]
+    hsv_h[cmax_idx == 1] = (((rgb[:, 2:3] - rgb[:, 0:1]) / delta) + 2)[cmax_idx == 1]
+    hsv_h[cmax_idx == 2] = (((rgb[:, 0:1] - rgb[:, 1:2]) / delta) + 4)[cmax_idx == 2]
+    hsv_h[cmax_idx == 3] = 0.
+    hsv_h /= 6.
+    hsv_s = torch.where(cmax == 0, torch.tensor(0.).type_as(rgb), delta / cmax)
+    hsv_v = cmax
+    return torch.cat([hsv_h, hsv_s, hsv_v], dim=1)
+
+
+def hsv2rgb_torch(hsv: torch.Tensor) -> torch.Tensor:
+    hsv_h, hsv_s, hsv_v = hsv[:, 0:1], hsv[:, 1:2], hsv[:, 2:3]
+    _c = hsv_v * hsv_s
+    _x = _c * (- torch.abs(hsv_h * 6. % 2. - 1) + 1.)
+    _m = hsv_v - _c
+    _o = torch.zeros_like(_c)
+    idx = (hsv_h * 6.).type(torch.uint8)  # hue sector index in [0, 5]
+    idx = (idx % 6).expand(-1, 3, -1, -1)
+    rgb = torch.empty_like(hsv)
+    rgb[idx == 0] = torch.cat([_c, _x, _o], dim=1)[idx == 0]
+    rgb[idx == 1] = torch.cat([_x, _c, _o], dim=1)[idx == 1]
+    rgb[idx == 2] = torch.cat([_o, _c, _x], dim=1)[idx == 2]
+    rgb[idx == 3] = torch.cat([_o, _x, _c], dim=1)[idx == 3]
+    rgb[idx == 4] = torch.cat([_x, _o, _c], dim=1)[idx == 4]
+    rgb[idx == 5] = torch.cat([_c, _o, _x], dim=1)[idx == 5]
+    rgb += _m
+    return rgb
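A quick way to sanity-check the conversion pair is a round trip on random data; a minimal sketch (the hue sector is quantized through uint8 in hsv2rgb_torch, so exact equality is not expected):

    import torch
    from color import rgb2hsv_torch, hsv2rgb_torch

    rgb = torch.rand(1, 3, 64, 64)              # random (1, 3, H, W) image in [0, 1]
    rgb_rt = hsv2rgb_torch(rgb2hsv_torch(rgb))  # round trip through HSV
    print((rgb - rgb_rt).abs().max())           # should be near zero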
filter.py ADDED
@@ -0,0 +1,80 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+def diff_x(input, r):
+    assert input.dim() == 4
+
+    left = input[:, :, r:2 * r + 1]
+    middle = input[:, :, 2 * r + 1:] - input[:, :, :-2 * r - 1]
+    right = input[:, :, -1:] - input[:, :, -2 * r - 1:-r - 1]
+
+    output = torch.cat([left, middle, right], dim=2)
+
+    return output
+
+def diff_y(input, r):
+    assert input.dim() == 4
+
+    left = input[:, :, :, r:2 * r + 1]
+    middle = input[:, :, :, 2 * r + 1:] - input[:, :, :, :-2 * r - 1]
+    right = input[:, :, :, -1:] - input[:, :, :, -2 * r - 1:-r - 1]
+
+    output = torch.cat([left, middle, right], dim=3)
+
+    return output
+
+class BoxFilter(nn.Module):
+    def __init__(self, r):
+        super(BoxFilter, self).__init__()
+
+        self.r = r
+
+    def forward(self, x):
+        assert x.dim() == 4
+
+        # box sums via cumulative sums (integral images) along height, then width
+        return diff_y(diff_x(x.cumsum(dim=2), self.r).cumsum(dim=3), self.r)
+
+
+class FastGuidedFilter(nn.Module):
+    def __init__(self, r, eps=1e-8):
+        super(FastGuidedFilter, self).__init__()
+
+        self.r = r
+        self.eps = eps
+        self.boxfilter = BoxFilter(r)
+
+    def forward(self, lr_x, lr_y, hr_x):
+        n_lrx, c_lrx, h_lrx, w_lrx = lr_x.size()
+        n_lry, c_lry, h_lry, w_lry = lr_y.size()
+        n_hrx, c_hrx, h_hrx, w_hrx = hr_x.size()
+
+        assert n_lrx == n_lry and n_lry == n_hrx
+        assert c_lrx == c_hrx and (c_lrx == 1 or c_lrx == c_lry)
+        assert h_lrx == h_lry and w_lrx == w_lry
+        assert h_lrx > 2 * self.r + 1 and w_lrx > 2 * self.r + 1
+
+        ## N: number of pixels inside each box window
+        N = self.boxfilter(torch.ones((1, 1, h_lrx, w_lrx), dtype=lr_x.dtype, device=lr_x.device))
+
+        ## mean_x
+        mean_x = self.boxfilter(lr_x) / N
+        ## mean_y
+        mean_y = self.boxfilter(lr_y) / N
+        ## cov_xy
+        cov_xy = self.boxfilter(lr_x * lr_y) / N - mean_x * mean_y
+        ## var_x
+        var_x = self.boxfilter(lr_x * lr_x) / N - mean_x * mean_x
+
+        ## A
+        A = cov_xy / (var_x + self.eps)
+        ## b
+        b = mean_y - A * mean_x
+
+        ## mean_A, mean_b: upsample the linear coefficients to the high-res grid
+        mean_A = F.interpolate(A, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
+        mean_b = F.interpolate(b, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
+
+        return mean_A * hr_x + mean_b
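The filter is self-contained, so its behavior is easy to probe with synthetic tensors. A minimal sketch that transfers a low-resolution signal onto a high-resolution guide (shapes chosen to satisfy the assertions in forward):

    import torch
    from filter import FastGuidedFilter

    gf = FastGuidedFilter(r=1)
    lr_guide = torch.rand(1, 1, 64, 64)     # low-res guidance x
    lr_signal = torch.rand(1, 1, 64, 64)    # low-res signal y
    hr_guide = torch.rand(1, 1, 256, 256)   # high-res guidance

    out = gf(lr_guide, lr_signal, hr_guide)
    print(out.shape)  # torch.Size([1, 1, 256, 256])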
loss.py ADDED
@@ -0,0 +1,31 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class L_exp(nn.Module):
+    def __init__(self, patch_size, mean_val):
+        super(L_exp, self).__init__()
+        self.pool = nn.AvgPool2d(patch_size)
+        self.mean_val = mean_val
+
+    def forward(self, x):
+        mean = self.pool(x) ** 0.5
+        d = torch.abs(torch.mean(torch.pow(mean - torch.FloatTensor([self.mean_val]), 2)))
+        return d
+
+
+class L_TV(nn.Module):
+    def __init__(self):
+        super(L_TV, self).__init__()
+
+    def forward(self, x):
+        batch_size = x.size()[0]
+        h_x = x.size()[2]
+        w_x = x.size()[3]
+        count_h = (h_x - 1) * w_x
+        count_w = h_x * (w_x - 1)
+        h_tv = torch.pow(x[:, :, 1:, :] - x[:, :, :h_x - 1, :], 2).sum()
+        w_tv = torch.pow(x[:, :, :, 1:] - x[:, :, :, :w_x - 1], 2).sum()
+        return 2 * (h_tv / count_h + w_tv / count_w) / batch_size
+
+
siren.py ADDED
@@ -0,0 +1,61 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+class SirenLayer(nn.Module):
+    def __init__(self, in_f, out_f, w0=30, is_first=False, is_last=False):
+        super().__init__()
+        self.in_f = in_f
+        self.w0 = w0
+        self.linear = nn.Linear(in_f, out_f)
+        self.is_first = is_first
+        self.is_last = is_last
+        if not self.is_last:
+            self.init_weights()
+
+    def init_weights(self):
+        b = 1 / self.in_f if self.is_first else np.sqrt(6 / self.in_f) / self.w0
+        with torch.no_grad():
+            self.linear.weight.uniform_(-b, b)
+
+    def forward(self, x):
+        x = self.linear(x)
+        return torch.sigmoid(x) if self.is_last else torch.sin(self.w0 * x)
+
+
+class INF(nn.Module):
+    def __init__(self, patch_dim, num_layers, hidden_dim, add_layer, weight_decay=None):
+        '''
+        `add_layer` should be in the range [1, num_layers - 2].
+        '''
+        super().__init__()
+
+        patch_layers = [SirenLayer(patch_dim, hidden_dim, is_first=True)]
+        spatial_layers = [SirenLayer(2, hidden_dim, is_first=True)]
+        output_layers = []
+
+        for _ in range(1, add_layer - 2):
+            patch_layers.append(SirenLayer(hidden_dim, hidden_dim))
+            spatial_layers.append(SirenLayer(hidden_dim, hidden_dim))
+        patch_layers.append(SirenLayer(hidden_dim, hidden_dim // 2))
+        spatial_layers.append(SirenLayer(hidden_dim, hidden_dim // 2))
+
+        for _ in range(add_layer, num_layers - 1):
+            output_layers.append(SirenLayer(hidden_dim, hidden_dim))
+        output_layers.append(SirenLayer(hidden_dim, 1, is_last=True))
+
+        self.patch_net = nn.Sequential(*patch_layers)
+        self.spatial_net = nn.Sequential(*spatial_layers)
+        self.output_net = nn.Sequential(*output_layers)
+
+        if not weight_decay:
+            weight_decay = [0.1, 0.0001, 0.001]
+
+        self.params = []
+        self.params += [{'params': self.spatial_net.parameters(), 'weight_decay': weight_decay[0]}]
+        self.params += [{'params': self.patch_net.parameters(), 'weight_decay': weight_decay[1]}]
+        self.params += [{'params': self.output_net.parameters(), 'weight_decay': weight_decay[2]}]
+
+    def forward(self, patch, spatial):
+        return self.output_net(torch.cat((self.patch_net(patch), self.spatial_net(spatial)), -1))
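With the hyperparameters used elsewhere in this commit (window=1, hence patch_dim=1), a forward pass is easy to shape-check on dummy inputs. A minimal sketch:

    import torch
    from siren import INF

    net = INF(patch_dim=1, num_layers=4, hidden_dim=256, add_layer=2)
    patches = torch.rand(256, 256, 1)  # per-pixel context patches, (H, W, window**2)
    coords = torch.rand(256, 256, 2)   # normalized (x, y) coordinates

    out = net(patches, coords)
    print(out.shape)  # torch.Size([256, 256, 1]); callers reshape to (1, 1, 256, 256)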
utils.py ADDED
@@ -0,0 +1,76 @@
+from PIL import Image
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from filter import FastGuidedFilter
+
+
+def get_image(path):
+    """
+    Reads and returns an RGB image as a (1, 3, H, W) tensor.
+    """
+    image = torch.from_numpy(np.array(Image.open(path))).float()
+    image = image / torch.max(image)
+    image = torch.movedim(image, -1, 0).unsqueeze(0)
+    return image
+
+
+def get_v_component(img_hsv):
+    """
+    Assumes a (1, 3, H, W) HSV image; returns the V channel as (1, 1, H, W).
+    """
+    return img_hsv[:, -1].unsqueeze(0)
+
+
+def replace_v_component(img_hsv, v_new):
+    """
+    Replaces the V component of an HSV image (1, 3, H, W).
+    """
+    img_hsv[:, -1] = v_new
+    return img_hsv
+
+
+def interpolate_image(img, H, W):
+    """
+    Resizes the image to the new resolution.
+    """
+    return F.interpolate(img, size=(H, W))
+
+
+def get_coords(H, W):
+    """
+    Creates a normalized coordinate grid for the INF.
+    """
+    coords = np.dstack(np.meshgrid(np.linspace(0, 1, H), np.linspace(0, 1, W)))
+    coords = torch.from_numpy(coords).float()
+    return coords
+
+
+def get_patches(img, KERNEL_SIZE):
+    """
+    Creates a tensor whose channel dimension holds the KERNEL_SIZE**2 context-window values.
+    """
+    # One one-hot kernel per window position, so conv2d extracts raw neighborhoods.
+    kernel = torch.zeros((KERNEL_SIZE ** 2, 1, KERNEL_SIZE, KERNEL_SIZE))
+
+    for i in range(KERNEL_SIZE):
+        for j in range(KERNEL_SIZE):
+            kernel[i * KERNEL_SIZE + j, 0, i, j] = 1
+
+    pad = nn.ReflectionPad2d(KERNEL_SIZE // 2)
+    im_padded = pad(img)
+
+    extracted = F.conv2d(im_padded, kernel, padding=0).squeeze(0)
+
+    return torch.movedim(extracted, 0, -1)
+
+
+def filter_up(x_lr, y_lr, x_hr, r=1):
+    """
+    Applies the fast guided filter to upscale the enhanced V channel.
+    """
+    guided_filter = FastGuidedFilter(r=r)
+    y_hr = guided_filter(x_lr, y_lr, x_hr)
+    y_hr = torch.clip(y_hr, 0, 1)
+    return y_hr
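Taken together, these helpers fix the shapes that the INF consumes. A minimal sketch of the preprocessing flow on a random V channel (no file I/O, so get_image is bypassed):

    import torch
    from utils import interpolate_image, get_coords, get_patches

    img_v = torch.rand(1, 1, 400, 600)             # V channel of an HSV image
    img_v_lr = interpolate_image(img_v, 256, 256)  # (1, 1, 256, 256)
    coords = get_coords(256, 256)                  # (256, 256, 2)
    patches = get_patches(img_v_lr, 3)             # (256, 256, 9) for a 3x3 window
    print(img_v_lr.shape, coords.shape, patches.shape)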