# PLPQ / plpq.py
# Hugging Face Hub page header (not part of the module):
#   uploader: TheTrueJard
#   commit f500667 (verified): "Upload folder using huggingface_hub"
from transformers import PreTrainedModel
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from .wavelet import WaveletTransform
from .pfsq import PFSQ
from .config import PLPQConfig
class PLPQ(PreTrainedModel):
    """Pyramidal Local Patch Quantizer.

    The model is made of four parts:

    * ``encoder``: an input projection, a SiLU, and the blocks listed in
      ``config.encoder_blocks``;
    * ``quantizer``: a ``PFSQ`` module;
    * ``coarse_decoder``: a single 1x1 convolution;
    * ``decoder``: the blocks listed in ``config.decoder_blocks`` followed by
      an output projection.

    Every entry of ``config.encoder_blocks`` / ``config.decoder_blocks`` is a
    sequence whose first item is a block-type tag and whose remaining items
    are passed positionally to the block constructor. The tag ``"ResBlock"``
    selects ``PatchResidualConvBlock``; any other tag selects ``Downsample``
    in the encoder and ``Upsample`` in the decoder.
    """

    config_class = PLPQConfig

    def __init__(self, config):
        """Build all sub-modules from ``config``.

        Parameters:
            config: model configuration; read for ``patch_size``,
                ``num_in_channels``, ``num_out_channels``, ``encoder_blocks``,
                ``decoder_blocks``, ``levels``, ``num_quantizers`` and the
                optional ``use_wavelets`` flag.
        """
        super().__init__(config)
        self.config = config

        # Index [1] / [2] of a block spec are the first two constructor
        # arguments, i.e. the block's input / output channel counts.
        first_encoder_in = config.encoder_blocks[0][1]
        last_encoder_out = config.encoder_blocks[-1][2]
        last_decoder_out = config.decoder_blocks[-1][2]

        # getattr (rather than peeking into __dict__) also honours a default
        # declared at class level on the config.
        if getattr(config, "use_wavelets", False):
            wavelets = WaveletTransform(patch_size=config.patch_size)
            wavelet_channels = wavelets.num_transformed_channels(config.num_in_channels)
            in_proj = nn.Sequential(
                wavelets,
                # 1x1 convolution: keep the projection fully local
                nn.Conv2d(wavelet_channels, first_encoder_in, kernel_size=1, stride=1),
            )
            out_proj = nn.Sequential(
                nn.Conv2d(last_decoder_out, wavelet_channels, kernel_size=3, stride=1, padding=1),
                WaveletTransform(patch_size=config.patch_size, inverse=True),
            )
        else:
            # Non-overlapping patch embedding: kernel size == stride == patch size.
            in_proj = nn.Conv2d(
                config.num_in_channels, first_encoder_in,
                kernel_size=config.patch_size, stride=config.patch_size,
            )
            out_proj = nn.Conv2d(
                last_decoder_out, config.num_out_channels,
                kernel_size=3, stride=1, padding=1,
            )

        self.encoder = nn.Sequential(
            in_proj,
            nn.SiLU(),
            *self._build_blocks(config.encoder_blocks, Downsample),
        )

        # Pyramidal Quantizer
        self.quantizer = PFSQ(
            levels=config.levels,                 # number of levels for each codebook
            num_codebooks=config.num_quantizers,  # number of quantizers
            dim=last_encoder_out,                 # input feature dimension
        )

        # Coarse decoder output -> 32x32 supervision
        self.coarse_decoder = nn.Conv2d(
            len(config.levels), config.num_out_channels, kernel_size=1, stride=1
        )

        self.decoder = nn.Sequential(
            *self._build_blocks(config.decoder_blocks, Upsample),
            out_proj,
        )

    @staticmethod
    def _build_blocks(block_specs, resample_cls):
        """Instantiate one module per spec: ``"ResBlock"`` -> residual block, anything else -> ``resample_cls``."""
        return [
            PatchResidualConvBlock(*spec[1:]) if spec[0] == "ResBlock" else resample_cls(*spec[1:])
            for spec in block_specs
        ]

    def get_num_params(self) -> int:
        """Return the number of parameters in the model."""
        return sum(p.numel() for p in self.parameters())

    @torch.no_grad()
    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        """
        Quantize the input tensor

        Parameters:
            x (torch.Tensor): The input tensor of shape (b, c, h, w)

        Returns:
            torch.Tensor: The indices tensor of shape (b, t, n_quantizers)
        """
        # Channels-last, then flatten the spatial grid into a token axis.
        z = self.encoder(x).permute(0, 2, 3, 1).contiguous()
        b, h, w, _ = z.shape
        z = z.view(b, h * w, -1)
        # The quantizer returns three values; only the last one (the codes) is needed here.
        _, _, all_codes = self.quantizer(z)
        return all_codes

    @torch.no_grad()
    def decode(self, indices: torch.Tensor) -> torch.Tensor:
        """
        Decode a tensor, inverse of self.quantize

        When ``indices`` carries a single code per token (last dimension of
        size 1) only ``coarse_decoder`` is applied; otherwise the full
        ``decoder`` is used.

        Parameters:
            indices (torch.Tensor): The input codes of shape (b, t, n_quantizers)

        Returns:
            torch.Tensor: The decoded tensor of shape (b, c, h, w)

        Raises:
            ValueError: if the number of tokens ``t`` is not a perfect square,
                since the tokens are laid out on a square (h == w) grid.
        """
        ncodes = indices.shape[-1]
        emb = self.quantizer.indices_to_codes(indices).squeeze(-1)

        # reshape [b t c] -> [b c h w]; the grid is assumed square.
        b, num_tokens = emb.size(0), emb.size(1)
        side = math.isqrt(num_tokens)  # exact integer square root, no float rounding
        if side * side != num_tokens:
            raise ValueError(
                f"decode expects a square token grid, but got {num_tokens} tokens per sample"
            )
        emb = emb.permute(0, 2, 1).view(b, -1, side, side).contiguous()

        if ncodes == 1:
            return self.coarse_decoder(emb)

        # full decoder: full image prediction
        return self.decoder(emb)
class LayerNorm(nn.Module):
    """Layer normalisation with a learnable scale and an optional learnable shift.

    Written by hand so that the bias term can be switched off: with
    ``bias=False`` no bias parameter is created and ``None`` is passed to
    ``F.layer_norm``.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        # Scale starts at one, shift (when enabled) at zero.
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        # Normalise over the trailing dimension(s) that match the weight's shape.
        return F.layer_norm(
            input,
            normalized_shape=self.weight.shape,
            weight=self.weight,
            bias=self.bias,
            eps=1e-5,
        )
class PatchResidualConvBlock(nn.Module):
    """Residual block: channel-wise LayerNorm, two SiLU-activated convolutions, skip connection.

    The output of the second convolution is added to the block's input, so the
    two tensors must have shapes that are compatible for addition.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, kernel_size, stride, padding, dorpout=0.1) -> None:
        # ``dorpout`` (sic) is the dropout probability; the spelling is part of
        # the public signature and is therefore kept.
        super().__init__()
        self.nonlinearity = nn.SiLU()
        self.ln1 = LayerNorm(in_dim, bias=True)
        self.dropout = nn.Dropout(dorpout)
        # Both convolutions share the same kernel geometry.
        conv_geometry = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv1 = nn.Conv2d(in_dim, hidden_dim, **conv_geometry)
        self.conv2 = nn.Conv2d(hidden_dim, out_dim, **conv_geometry)

    def _channel_norm(self, x):
        """Apply ``ln1`` across the channel axis of a (b, c, h, w) tensor."""
        batch, channels, height, width = x.shape
        # Move channels last and flatten every pixel into its own row.
        pixels = x.permute(0, 2, 3, 1).reshape(batch * height * width, channels)
        normed = self.ln1(pixels).reshape(batch, height, width, channels)
        # Back to channels-first for the convolutions.
        return normed.permute(0, 3, 1, 2).contiguous()

    def forward(self, x):
        hidden = self.conv1(self._channel_norm(x))
        hidden = self.dropout(self.nonlinearity(hidden))
        update = self.nonlinearity(self.conv2(hidden))
        return update + x
class Upsample(nn.Module):
    """Double the spatial size with nearest-neighbour interpolation, then apply a 3x3 convolution."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 3x3, stride 1, padding 1: the convolution itself keeps height and width.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upscaled = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upscaled)
class Downsample(nn.Module):
    """Halve the spatial size with a stride-2 3x3 convolution after zero-padding right and bottom."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The convolution is unpadded: torch convolutions cannot pad
        # asymmetrically, so the padding is applied by hand in ``forward``.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        # (left, right, top, bottom): one zero column on the right, one zero row at the bottom.
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)