import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable  # noqa: F401  (kept for backward compatibility; unused here)


def diff_x(input: torch.Tensor, r: int) -> torch.Tensor:
    """Convert a running sum along dim 2 into border-clipped window sums of radius ``r``.

    ``input`` is expected to hold an inclusive cumulative sum along dim 2
    (``BoxFilter`` passes ``x.cumsum(dim=2)``). Writing ``c`` for that running
    sum and ``H`` for ``input.size(2)``, output index ``i`` is
    ``c[min(i + r, H - 1)] - c[i - r - 1]`` with ``c[k] = 0`` for ``k < 0``,
    i.e. the sum of the original values over indices ``max(0, i - r)`` ..
    ``min(H - 1, i + r)``. The output has the same shape as ``input``.

    NOTE: despite the name, this operates on dim 2, which the rest of this
    file treats as the height axis (see ``FastGuidedFilter.forward``).

    Args:
        input: 4-D tensor holding a running sum along dim 2.
        r: Window radius; ``input.size(2)`` must be at least ``2 * r + 1``.

    Raises:
        AssertionError: (via ``assert``) if ``input`` is not 4-D or is too
            short along dim 2.
    """
    assert input.dim() == 4
    # With fewer than 2r+1 entries the three slices below no longer add up to
    # the input length, so the result would silently have the wrong shape.
    assert input.size(2) >= 2 * r + 1
    # i in [0, r]: the window starts at index 0, so the sum is just c[i + r].
    left = input[:, :, r:2 * r + 1]
    # i in [r + 1, H - r - 1]: full window, c[i + r] - c[i - r - 1].
    middle = input[:, :, 2 * r + 1:] - input[:, :, :-2 * r - 1]
    # i in [H - r, H - 1]: window clipped at the end, c[H - 1] - c[i - r - 1].
    right = input[:, :, -1:] - input[:, :, -2 * r - 1:-r - 1]
    return torch.cat([left, middle, right], dim=2)


def diff_y(input: torch.Tensor, r: int) -> torch.Tensor:
    """Convert a running sum along dim 3 into border-clipped window sums of radius ``r``.

    Same as ``diff_x`` but along dim 3, which the rest of this file treats as
    the width axis. ``BoxFilter`` passes a tensor that has been cumulatively
    summed along dim 3.

    Args:
        input: 4-D tensor holding a running sum along dim 3.
        r: Window radius; ``input.size(3)`` must be at least ``2 * r + 1``.

    Raises:
        AssertionError: (via ``assert``) if ``input`` is not 4-D or is too
            short along dim 3.
    """
    assert input.dim() == 4
    # Shorter inputs would make the slices below produce a wrong-length result.
    assert input.size(3) >= 2 * r + 1
    # Leading entries: window clipped at index 0.
    left = input[:, :, :, r:2 * r + 1]
    # Interior entries: full window.
    middle = input[:, :, :, 2 * r + 1:] - input[:, :, :, :-2 * r - 1]
    # Trailing entries: window clipped at the last index.
    right = input[:, :, :, -1:] - input[:, :, :, -2 * r - 1:-r - 1]
    return torch.cat([left, middle, right], dim=3)


class BoxFilter(nn.Module):
    """Sum each value's ``(2r+1) x (2r+1)`` neighbourhood over dims 2 and 3.

    Windows are clipped at the borders (no padding), so border outputs sum
    fewer elements; dividing by ``BoxFilter(r)`` of an all-ones tensor turns
    the sums into means, as ``FastGuidedFilter`` does. The sums are computed
    from running sums, so the cost does not depend on ``r``.

    Args:
        r: Window radius. Both spatial sizes of the input must be at least
            ``2 * r + 1``.
    """

    def __init__(self, r):
        super(BoxFilter, self).__init__()

        self.r = r

    def forward(self, x):
        """Return the windowed sums of the 4-D tensor ``x`` (same shape as ``x``)."""
        assert x.dim() == 4

        return diff_y(diff_x(x.cumsum(dim=2), self.r).cumsum(dim=3), self.r)


class FastGuidedFilter(nn.Module):
    """Guided filter whose linear coefficients are fitted at low resolution.

    A per-window linear model ``y ~ A * x + b`` is fitted on the low-resolution
    pair ``(lr_x, lr_y)``. ``A`` and ``b`` are then bilinearly upsampled to the
    spatial size of ``hr_x`` and applied to it.

    Args:
        r: Window radius; each window covers up to ``(2r+1) x (2r+1)`` pixels.
        eps: Regularizer added to the local variance of ``lr_x`` before
            dividing, so locally constant regions do not divide by zero.
    """

    def __init__(self, r, eps=1e-8):
        super(FastGuidedFilter, self).__init__()

        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)

    def forward(self, lr_x, lr_y, hr_x):
        """Filter ``hr_x`` with coefficients fitted on ``(lr_x, lr_y)``.

        Args:
            lr_x: Low-resolution guide, shape ``(N, C, h, w)``.
            lr_y: Low-resolution target, shape ``(N, Cy, h, w)`` with
                ``C == 1`` or ``C == Cy``.
            hr_x: High-resolution guide, shape ``(N, C, H, W)``.

        Returns:
            Tensor of shape ``(N, Cy, H, W)``: ``mean_A * hr_x + mean_b``.

        Raises:
            AssertionError: (via ``assert``) on mismatched shapes, or if ``h``
                or ``w`` is smaller than ``2 * r + 1``.
        """
        n_lrx, c_lrx, h_lrx, w_lrx = lr_x.size()
        n_lry, c_lry, h_lry, w_lry = lr_y.size()
        n_hrx, c_hrx, h_hrx, w_hrx = hr_x.size()

        assert n_lrx == n_lry and n_lry == n_hrx
        assert c_lrx == c_hrx and (c_lrx == 1 or c_lrx == c_lry)
        assert h_lrx == h_lry and w_lrx == w_lry
        # The box filter needs at least one full window per axis; a size of
        # exactly 2r+1 is valid.
        assert h_lrx >= 2 * self.r + 1 and w_lrx >= 2 * self.r + 1

        # Number of pixels in each (border-clipped) window, used to turn box
        # sums into means. Shape (1, 1, h, w) broadcasts over batch/channels.
        N = self.boxfilter(lr_x.new_ones((1, 1, h_lrx, w_lrx)))

        mean_x = self.boxfilter(lr_x) / N
        mean_y = self.boxfilter(lr_y) / N
        # Local covariance / variance computed as E[ab] - E[a]E[b].
        cov_xy = self.boxfilter(lr_x * lr_y) / N - mean_x * mean_y
        var_x = self.boxfilter(lr_x * lr_x) / N - mean_x * mean_x

        # eps-regularized least-squares fit of y ~ A * x + b in each window.
        A = cov_xy / (var_x + self.eps)
        b = mean_y - A * mean_x

        # NOTE(review): A and b are upsampled directly, without a second
        # box-filter averaging pass that the names mean_A/mean_b suggest.
        # Left unchanged to preserve outputs — confirm this is intentional.
        mean_A = F.interpolate(A, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
        mean_b = F.interpolate(b, (h_hrx, w_hrx), mode='bilinear', align_corners=True)

        return mean_A * hr_x + mean_b