import torch
import torch.nn as nn
import torch.nn.functional as F


class L_exp(nn.Module):
    """Penalise the distance between patch-level brightness and a target.

    The input is average-pooled over non-overlapping ``patch_size`` windows,
    the square root of every pooled value is taken, and the mean squared
    difference between those values and ``mean_val`` is returned.

    Args:
        patch_size: Kernel size handed to ``nn.AvgPool2d`` (the stride
            defaults to the kernel size).
        mean_val: Scalar target the square-rooted patch means are pulled
            towards.
    """

    def __init__(self, patch_size, mean_val):
        super(L_exp, self).__init__()
        self.pool = nn.AvgPool2d(patch_size)
        self.mean_val = mean_val

    def forward(self, x):
        """Return the scalar loss for ``x``.

        Args:
            x: Tensor accepted by ``nn.AvgPool2d``. Negative pooled values
                yield NaN because of the square root, so the input is
                presumably non-negative (e.g. image intensities) -- confirm
                against the caller.

        Returns:
            A 0-dim tensor holding the mean squared deviation.
        """
        patch_mean = self.pool(x) ** 0.5
        # Build the target on the input's device so the loss also works for
        # non-CPU tensors; float32 matches the former torch.FloatTensor, so
        # dtype promotion of the result is unchanged.
        target = torch.tensor(
            [self.mean_val], dtype=torch.float32, device=patch_mean.device
        )
        # A mean of squares is already non-negative, so no abs() is needed.
        return torch.mean(torch.pow(patch_mean - target, 2))


class L_TV(nn.Module):
    """Squared total-variation smoothness penalty over the last two dims."""

    def __init__(self):
        super(L_TV, self).__init__()

    def forward(self, x):
        """Return the normalised squared total variation of ``x``.

        Args:
            x: 4-D tensor; dim 0 is treated as the batch, dims 2 and 3 as
                height and width (dim 1 is presumably channels -- confirm
                against the caller).

        Returns:
            A 0-dim tensor: ``2 * (h_tv / count_h + w_tv / count_w) / batch``
            where ``h_tv`` / ``w_tv`` are the summed squared differences of
            vertically / horizontally adjacent elements.
        """
        batch_size, _, h_x, w_x = x.size()
        # Number of adjacent pairs per (batch, dim-1) slice in each direction.
        # Clamp to 1 so a one-pixel-high or one-pixel-wide input contributes
        # 0 for that direction instead of 0 / 0 = NaN.
        count_h = max((h_x - 1) * w_x, 1)
        count_w = max(h_x * (w_x - 1), 1)
        h_tv = torch.pow(x[:, :, 1:, :] - x[:, :, :h_x - 1, :], 2).sum()
        w_tv = torch.pow(x[:, :, :, 1:] - x[:, :, :, :w_x - 1], 2).sum()
        return 2 * (h_tv / count_h + w_tv / count_w) / batch_size