|
|
|
|
|
from transformers import PreTrainedModel |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import math |
|
|
|
|
|
from .wavelet import WaveletTransform |
|
|
from .pfsq import PFSQ |
|
|
from .config import PLPQConfig |
|
|
|
|
|
|
|
|
class PLPQ(PreTrainedModel):
    """Pyramidal Local Patch Quantizer.

    A convolutional encoder maps an image-like tensor to a feature grid; the
    grid is flattened into tokens and discretised by a ``PFSQ`` quantizer.
    Codes are mapped back either by a 1x1 ``coarse_decoder`` (when the code
    tensor has a single codebook column) or by the full convolutional
    ``decoder``.
    """

    config_class = PLPQConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # ``use_wavelets`` is read from the instance ``__dict__`` so a config
        # that never set the key falls back to the plain strided-conv path.
        if config.__dict__.get('use_wavelets', False):
            # The wavelet transform is given ``patch_size``; the projection
            # after it is only a 1x1 conv.
            wavelets = WaveletTransform(patch_size=config.patch_size)
            wavelet_channels = wavelets.num_transformed_channels(config.num_in_channels)
            in_proj = nn.Sequential(
                wavelets,
                nn.Conv2d(
                    wavelet_channels, config.encoder_blocks[0][1],
                    kernel_size=1, stride=1
                )
            )
            out_proj = nn.Sequential(
                nn.Conv2d(
                    config.decoder_blocks[-1][2], wavelet_channels,
                    kernel_size=3, stride=1, padding=1
                ),
                WaveletTransform(patch_size=config.patch_size, inverse=True)
            )
        else:
            # Patchify with a conv whose kernel and stride both equal patch_size.
            in_proj = nn.Conv2d(
                config.num_in_channels, config.encoder_blocks[0][1],
                kernel_size=config.patch_size, stride=config.patch_size
            )
            out_proj = nn.Conv2d(
                config.decoder_blocks[-1][2], config.num_out_channels,
                kernel_size=3, stride=1, padding=1
            )

        # NOTE: in_proj/out_proj are built before the stage blocks, matching
        # the original construction order (parameter init order under a seed).
        self.encoder = nn.Sequential(
            in_proj,
            nn.SiLU(),
            *self._build_stage(config.encoder_blocks, Downsample),
        )

        self.quantizer = PFSQ(
            levels=config.levels,
            num_codebooks=config.num_quantizers,
            dim=config.encoder_blocks[-1][2],
        )

        # 1x1 conv taking len(config.levels) channels; used by ``decode`` when
        # the codes have a single codebook column.
        self.coarse_decoder = nn.Conv2d(len(config.levels), config.num_out_channels, kernel_size=1, stride=1)

        self.decoder = nn.Sequential(
            *self._build_stage(config.decoder_blocks, Upsample),
            out_proj,
        )

    @staticmethod
    def _build_stage(block_specs, resample_cls):
        """Instantiate a list of blocks from config specs.

        Each spec is a sequence whose first item is a tag: ``"ResBlock"``
        builds a ``PatchResidualConvBlock``; any other tag builds
        ``resample_cls``. The remaining items are passed positionally to the
        constructor.
        """
        return [
            PatchResidualConvBlock(*spec[1:]) if spec[0] == "ResBlock" else resample_cls(*spec[1:])
            for spec in block_specs
        ]

    def get_num_params(self) -> int:
        """Return the number of parameters in the model."""
        return sum(p.numel() for p in self.parameters())

    @torch.no_grad()
    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        """
        Quantize the input tensor

        The encoder output of shape (b, c', h', w') is moved to channels-last
        and flattened row-major into t = h' * w' tokens before quantization.
        Pass ``height=h'`` / ``width=w'`` to ``decode`` when h' != w'.

        Parameters:
            x (torch.Tensor): The input tensor of shape (b, c, h, w)

        Returns:
            torch.Tensor: The indices tensor of shape (b, t, n_quantizers)
        """
        z = self.encoder(x).permute(0, 2, 3, 1).contiguous()
        b, h, w, _ = z.shape
        z = z.view(b, h * w, -1)
        _, _, all_codes = self.quantizer(z)
        return all_codes

    @staticmethod
    def _token_grid(num_tokens, height=None, width=None):
        """Resolve the (height, width) grid that ``num_tokens`` were flattened from.

        With neither dimension given, a square grid is assumed. With one
        given, the other is inferred.

        Raises:
            ValueError: if the tokens cannot be arranged into that grid exactly.
        """
        if height is None and width is None:
            # isqrt is exact for any int, unlike int(math.sqrt(...)).
            height = width = math.isqrt(num_tokens)
        elif height is None:
            height = num_tokens // width
        elif width is None:
            width = num_tokens // height
        if height * width != num_tokens:
            raise ValueError(
                f"cannot arrange {num_tokens} tokens into a {height}x{width} grid; "
                "pass height/width matching the grid produced by quantize()"
            )
        return height, width

    @torch.no_grad()
    def decode(self, indices: torch.Tensor, *, height=None, width=None) -> torch.Tensor:
        """
        Decode a tensor, inverse of self.quantize

        Parameters:
            indices (torch.Tensor): The input codes of shape (b, t, n_quantizers)
            height (int, optional): Rows of the token grid. Defaults to a square
                grid (or to t // width when only ``width`` is given).
            width (int, optional): Columns of the token grid. Defaults to a
                square grid (or to t // height when only ``height`` is given).

        Returns:
            torch.Tensor: The decoded tensor of shape (b, c, h, w)

        Raises:
            ValueError: if the t tokens do not form an exact height x width grid.
        """
        ncodes = indices.shape[-1]
        emb = self.quantizer.indices_to_codes(indices).squeeze(-1)

        b, num_tokens = emb.size(0), emb.size(1)
        h, w = self._token_grid(num_tokens, height, width)
        # (b, t, c) -> (b, c, t) -> (b, c, h, w); tokens were flattened row-major.
        emb = emb.permute(0, 2, 1).view(b, -1, h, w).contiguous()

        # A single codebook column goes through the lightweight 1x1 decoder.
        if ncodes == 1:
            return self.coarse_decoder(emb)

        return self.decoder(emb)
|
|
|
|
|
|
|
|
|
|
|
class LayerNorm(nn.Module):
    """Layer normalisation over the trailing axis, with an optional bias.

    Holds a learnable scale initialised to ones. When ``bias`` is truthy, it
    also holds a learnable shift initialised to zeros; otherwise ``self.bias``
    is ``None`` and no shift is applied. (PyTorch releases before 2.1 had no
    way to drop only the bias from ``torch.nn.LayerNorm``.)
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        # The weight's shape is the normalized shape (the trailing ``ndim`` axis).
        normalized_shape = self.weight.shape
        return F.layer_norm(input, normalized_shape, self.weight, self.bias, eps=1e-5)
|
|
|
|
|
|
|
|
|
|
|
class PatchResidualConvBlock(nn.Module):
    """Pre-norm residual block built from two convolutions.

    The input is layer-normalised over its channel axis and passed through
    ``conv1 -> SiLU -> dropout -> conv2 -> SiLU``. The result is then added
    back onto the input. Because of that identity shortcut, ``out_dim`` and
    the shared ``kernel_size``/``stride``/``padding`` must be chosen so the
    conv output can be added to the input.

    Args:
        in_dim: channels of the input (also the LayerNorm width).
        out_dim: channels produced by ``conv2``.
        hidden_dim: channels between ``conv1`` and ``conv2``.
        kernel_size, stride, padding: shared by both convolutions.
        dorpout: dropout probability applied after the first activation. The
            misspelled name is part of the public signature and is kept.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, kernel_size, stride, padding, dorpout=0.1) -> None:
        super().__init__()
        conv_opts = {"kernel_size": kernel_size, "stride": stride, "padding": padding}
        self.nonlinearity = nn.SiLU()
        self.ln1 = LayerNorm(in_dim, bias=True)
        self.dropout = nn.Dropout(dorpout)
        self.conv1 = nn.Conv2d(in_dim, hidden_dim, **conv_opts)
        self.conv2 = nn.Conv2d(hidden_dim, out_dim, **conv_opts)

    def forward(self, x):
        batch, channels, height, width = x.shape

        # LayerNorm acts on the last axis, so lay the pixels out as rows of
        # channel vectors, normalise them, then fold back to NCHW.
        pixels = x.permute(0, 2, 3, 1).reshape(batch * height * width, channels)
        normed = self.ln1(pixels).reshape(batch, height, width, channels)
        normed = normed.permute(0, 3, 1, 2).contiguous()

        hidden = self.dropout(self.nonlinearity(self.conv1(normed)))
        update = self.nonlinearity(self.conv2(hidden))
        # Identity shortcut.
        return update + x
|
|
|
|
|
|
|
|
|
|
|
class Upsample(nn.Module):
    """Double the spatial resolution, then mix channels with a 3x3 convolution.

    Nearest-neighbour interpolation by a factor of 2 is followed by a
    stride-1, padding-1 convolution from ``in_channels`` to ``out_channels``.
    The convolution keeps the doubled height and width.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        enlarged = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(enlarged)
|
|
|
|
|
|
|
|
|
|
|
class Downsample(nn.Module):
    """Reduce spatial resolution with a stride-2, 3x3 convolution.

    The convolution itself is unpadded. Instead, ``forward`` zero-pads one
    column on the right and one row at the bottom, so an H x W input yields
    floor(H / 2) x floor(W / 2) for H, W >= 2.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        # Pad order is (left, right, top, bottom): pad only the right and bottom edges.
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
|
|
|