from transformers import PreTrainedModel
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from .wavelet import WaveletTransform
from .pfsq import PFSQ
from .config import PLPQConfig


class PLPQ(PreTrainedModel):
    """Pyramidal Local Patch Quantizer"""

    config_class = PLPQConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        if getattr(config, "use_wavelets", False):
            # Wavelet front-end: decompose each patch, then mix the wavelet
            # channels with a 1x1 convolution so the projection stays local.
            wavelets = WaveletTransform(patch_size=config.patch_size)
            wavelet_channels = wavelets.num_transformed_channels(config.num_in_channels)
            in_proj = nn.Sequential(
                wavelets,
                nn.Conv2d(
                    wavelet_channels,
                    config.encoder_blocks[0][1],
                    kernel_size=1,
                    stride=1,  # keep fully local
                ),
            )
            out_proj = nn.Sequential(
                nn.Conv2d(
                    config.decoder_blocks[-1][2],
                    wavelet_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                ),
                WaveletTransform(patch_size=config.patch_size, inverse=True),
            )
        else:
            # Plain patchify: a strided convolution maps each
            # patch_size x patch_size patch to one latent vector.
            in_proj = nn.Conv2d(
                config.num_in_channels,
                config.encoder_blocks[0][1],
                kernel_size=config.patch_size,
                stride=config.patch_size,
            )
            out_proj = nn.Conv2d(
                config.decoder_blocks[-1][2],
                config.num_out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
            )

        self.encoder = nn.Sequential(
            in_proj,
            nn.SiLU(),
            *[
                PatchResidualConvBlock(*block_params[1:])
                if block_params[0] == "ResBlock"
                else Downsample(*block_params[1:])
                for block_params in config.encoder_blocks
            ],
        )

        # Pyramidal Quantizer
        self.quantizer = PFSQ(
            levels=config.levels,  # number of levels for each codebook
            num_codebooks=config.num_quantizers,  # number of quantizers
            dim=config.encoder_blocks[-1][2],  # input feature dimension; defaults to log2(codebook_size) if not defined
        )

        # Coarse decoder output -> 32x32 supervision
        self.coarse_decoder = nn.Conv2d(
            len(config.levels), config.num_out_channels, kernel_size=1, stride=1
        )

        self.decoder = nn.Sequential(
            *[
                PatchResidualConvBlock(*block_params[1:])
                if block_params[0] == "ResBlock"
                else Upsample(*block_params[1:])
                for block_params in config.decoder_blocks
            ],
            out_proj,
        )

    def get_num_params(self) -> int:
        """Return the number of parameters in the model."""
        return sum(p.numel() for p in self.parameters())

    @torch.no_grad()
    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        """
        Quantize the input tensor.

        Parameters:
            x (torch.Tensor): The input tensor of shape (b, c, h, w)

        Returns:
            torch.Tensor: The indices tensor of shape (b, t, n_quantizers)
        """
        # Encode to a latent grid, then flatten (h, w) into a token axis.
        z = self.encoder(x).permute(0, 2, 3, 1).contiguous()
        b, h, w, c = z.shape
        z = z.view(b, h * w, -1)
        quantized, coarse_quantized, all_codes = self.quantizer(z)
        return all_codes

    @torch.no_grad()
    def decode(self, indices: torch.Tensor) -> torch.Tensor:
        """
        Decode a tensor; inverse of self.quantize.

        Parameters:
            indices (torch.Tensor): The input codes of shape (b, t, n_quantizers)

        Returns:
            torch.Tensor: The decoded tensor of shape (b, c, h, w)
        """
        ncodes = indices.shape[-1]
        emb = self.quantizer.indices_to_codes(indices).squeeze(-1)
        # reshape [b t c] -> [b c h w]; assumes a square latent grid (t = h * w, h == w)
        b = emb.size(0)
        h = w = int(math.sqrt(emb.size(1)))
        emb = emb.permute(0, 2, 1).view(b, -1, h, w).contiguous()
        if ncodes == 1:
            # Only the coarsest codebook was provided: emit the low-resolution prediction.
            return self.coarse_decoder(emb)
        # full decoder: full image prediction
        return self.decoder(emb)


class LayerNorm(nn.Module):
    """LayerNorm but with an optional bias.

    nn.LayerNorm in older PyTorch versions doesn't accept bias=False."""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)


class PatchResidualConvBlock(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim, kernel_size, stride, padding, dropout=0.1) -> None:
        super().__init__()
        self.nonlinearity = nn.SiLU()
        self.ln1 = LayerNorm(in_dim, bias=True)
        self.dropout = nn.Dropout(dropout)
        self.conv1 = nn.Conv2d(in_dim, hidden_dim, kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv2 = nn.Conv2d(hidden_dim, out_dim, kernel_size=kernel_size, stride=stride, padding=padding)

    def forward(self, x):
        # LayerNorm over the channel dim: move channels last, normalize, move back.
        z = self.ln1(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous()
        z = self.nonlinearity(self.conv1(z))
        z = self.dropout(z)
        z = self.nonlinearity(self.conv2(z))
        # Residual connection: requires out_dim == in_dim and a shape-preserving conv stack.
        return z + x


class Upsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # Nearest-neighbor upsample by 2x, then a 3x3 conv to smooth the result.
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        # Pad right/bottom by one so the strided conv halves h and w exactly.
        pad = (0, 1, 0, 1)
        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x
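

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; every config value below is hypothetical).
# Block tuples follow how __init__ unpacks them: ("ResBlock", in_dim, out_dim,
# hidden_dim, kernel_size, stride, padding) and ("Downsample" / "Upsample",
# in_channels, out_channels). This assumes PLPQConfig accepts these fields as
# keyword arguments and that PFSQ.indices_to_codes returns dim-channel codes.
# Run with `python -m <package>.<module>` so the relative imports resolve.
if __name__ == "__main__":
    config = PLPQConfig(
        num_in_channels=3,
        num_out_channels=3,
        patch_size=8,
        use_wavelets=False,
        encoder_blocks=[
            ("ResBlock", 64, 64, 128, 3, 1, 1),
            ("Downsample", 64, 64),
        ],
        decoder_blocks=[
            ("Upsample", 64, 64),
            ("ResBlock", 64, 64, 128, 3, 1, 1),
        ],
        levels=[8, 8, 8, 5],  # FSQ levels per latent dimension
        num_quantizers=2,
    )
    model = PLPQ(config).eval()
    print(f"parameters: {model.get_num_params():,}")

    # 256/8 patches -> 32x32 grid, then one 2x downsample -> 16x16 latent grid (t = 256).
    x = torch.randn(1, 3, 256, 256)
    codes = model.quantize(x)    # (1, 256, 2): one index per token per quantizer
    recon = model.decode(codes)  # full-decoder prediction; (1, 3, 32, 32) with this toy config
    # Passing only the coarsest codebook routes through coarse_decoder instead;
    # whether this shape-checks depends on PFSQ returning len(levels)-channel
    # codes for a single codebook:
    # coarse = model.decode(codes[..., :1])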