from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from filter import FastGuidedFilter


def get_image(path):
    """ Reads an image from `path` and returns it as an RGB tensor of shape (1,3,H,W),
    scaled to [0,1] by its maximum pixel value. """
    image = torch.from_numpy(np.array(Image.open(path))).float()
    image = image / torch.max(image)
    # (H,W,3) -> (1,3,H,W)
    image = torch.movedim(image, -1, 0).unsqueeze(0)
    return image


def get_v_component(img_hsv):
    """ Returns the V (value) channel of a (1,3,H,W) HSV image as a (1,1,H,W) tensor. """
    return img_hsv[:,-1].unsqueeze(0)


def replace_v_component(img_hsv, v_new):
    """ Replaces the V channel of a (1,3,H,W) HSV image in place and returns the modified tensor. """
    img_hsv[:,-1] = v_new
    return img_hsv


def interpolate_image(img, H, W):
    """ Resizes the image to (H,W) with F.interpolate (default 'nearest' mode). """
    return F.interpolate(img, size=(H,W))


def get_coords(H, W):
    """ Creates a normalized [0,1] coordinate grid for the INF.
    Note: np.meshgrid uses 'xy' indexing by default, so the result has shape (W,H,2);
    this is indistinguishable from (H,W,2) for square grids (H == W). """
    coords = np.dstack(np.meshgrid(np.linspace(0, 1, H), np.linspace(0, 1, W)))
    coords = torch.from_numpy(coords).float()
    return coords


def get_patches(img, KERNEL_SIZE):
    """ Unfolds a (1,1,H,W) image into per-pixel neighborhoods,
    returned as an (H,W,KERNEL_SIZE**2) tensor. """
    # One one-hot kernel per position in the KERNEL_SIZE x KERNEL_SIZE window,
    # so the convolution copies each neighbor into its own output channel.
    kernel = torch.zeros((KERNEL_SIZE ** 2, 1, KERNEL_SIZE, KERNEL_SIZE))
    for i in range(KERNEL_SIZE):
        for j in range(KERNEL_SIZE):
            kernel[i * KERNEL_SIZE + j, 0, i, j] = 1

    pad = nn.ReflectionPad2d(KERNEL_SIZE//2)
    im_padded = pad(img)
    extracted = F.conv2d(im_padded, kernel, padding=0).squeeze(0)
    return torch.movedim(extracted, 0, -1)


def filter_up(x_lr, y_lr, x_hr, r=1):
    """ Applies the fast guided filter to upsample the low-resolution prediction y_lr
    to the resolution of the guide x_hr, then clips the result to [0,1]. """
    guided_filter = FastGuidedFilter(r=r)
    y_hr = guided_filter(x_lr, y_lr, x_hr)
    y_hr = torch.clip(y_hr, 0, 1)
    return y_hr
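

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the pipeline proper).
# It chains the helpers above on a synthetic image so it runs without an input
# file; in practice the RGB tensor would come from get_image(path). Assumptions:
# the script is run inside the repo so that filter.py (FastGuidedFilter)
# resolves, kornia is installed for the RGB<->HSV conversion (the repo may ship
# its own conversion helpers), and the "enhancement" step is a stand-in gamma
# curve rather than the actual INF model. DOWN_SIZE and KERNEL_SIZE are example
# values, not settings prescribed by this module.
if __name__ == "__main__":
    from kornia.color import rgb_to_hsv, hsv_to_rgb

    DOWN_SIZE = 256      # low-resolution working size (example value)
    KERNEL_SIZE = 7      # context window passed to get_patches (example value)

    # Synthetic RGB image in [0,1]; replace with get_image("path/to/image.png").
    img_rgb = torch.rand(1, 3, 512, 768)

    img_hsv = rgb_to_hsv(img_rgb)
    img_v = get_v_component(img_hsv)                          # (1,1,H,W)
    img_v_lr = interpolate_image(img_v, DOWN_SIZE, DOWN_SIZE) # (1,1,DOWN_SIZE,DOWN_SIZE)

    # Inputs typically fed to the implicit network: per-pixel patches and
    # normalized coordinates of the low-resolution V channel.
    patches = get_patches(img_v_lr, KERNEL_SIZE)   # (DOWN_SIZE, DOWN_SIZE, KERNEL_SIZE**2)
    coords = get_coords(DOWN_SIZE, DOWN_SIZE)      # (DOWN_SIZE, DOWN_SIZE, 2)

    # Stand-in enhancement of the low-resolution V channel (gamma curve),
    # in place of the trained model's prediction.
    enhanced_v_lr = img_v_lr ** 0.5

    # Upsample the enhanced V channel with the guided filter and rebuild RGB.
    enhanced_v = filter_up(img_v_lr, enhanced_v_lr, img_v)
    img_out = hsv_to_rgb(replace_v_component(img_hsv, enhanced_v))

    print(patches.shape, coords.shape, img_out.shape)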