TheTrueJard committed
Commit 748c921 · verified · 1 Parent(s): 21bb585

Upload folder using huggingface_hub

Files changed (6)
  1. __init__.py +5 -0
  2. config.json +9 -0
  3. config.py +32 -0
  4. pfsq.py +234 -0
  5. plpq.py +196 -0
  6. wavelet.py +167 -0
__init__.py ADDED
@@ -0,0 +1,5 @@
+
+ from .plpq import PLPQ
+ from .pfsq import PFSQ
+ from .config import PLPQConfig
+ from .wavelet import WaveletTransform
config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "_name_or_path": "StanfordNeuroAILab/PLPQ",
+   "architectures": ["PLPQ"],
+   "auto_map": {
+     "AutoConfig": "config.PLPQConfig",
+     "AutoModel": "plpq.PLPQ"
+   },
+   "model_type": "PLPQ"
+ }
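
The auto_map entries above let the Transformers Auto classes resolve the custom config and model code shipped with this repo. A minimal loading sketch (assuming the hosted config.json also provides the PLPQConfig constructor fields, which have no defaults):

import torch
from transformers import AutoConfig, AutoModel

repo = "StanfordNeuroAILab/PLPQ"
# trust_remote_code is required so config.py / plpq.py from the repo are executed
config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(repo, trust_remote_code=True)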
config.py ADDED
@@ -0,0 +1,32 @@
+
+ from typing import Tuple, List
+ from transformers import PretrainedConfig
+
+ class PLPQConfig(PretrainedConfig):
+     model_type: str = "PLPQ"
+     def __init__(self,
+                  image_size: Tuple[int, int],
+                  patch_size: int,
+                  dropout: float,
+                  vocab_size: int,
+                  levels: List[int],
+                  num_quantizers: int,
+                  num_in_channels: int,
+                  num_out_channels: int,
+                  use_wavelets: bool,
+                  encoder_blocks: List[List],
+                  decoder_blocks: List[List],
+                  **kwargs
+                  ):
+         self.image_size = image_size
+         self.patch_size = patch_size
+         self.dropout = dropout
+         self.vocab_size = vocab_size
+         self.levels = levels
+         self.num_quantizers = num_quantizers
+         self.num_in_channels = num_in_channels
+         self.num_out_channels = num_out_channels
+         self.use_wavelets = use_wavelets
+         self.encoder_blocks = encoder_blocks
+         self.decoder_blocks = decoder_blocks
+         super().__init__(**kwargs)
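
A hedged instantiation sketch. Every constructor argument is required (there are no defaults), and the values below are purely illustrative; the encoder_blocks / decoder_blocks entries follow the ("ResBlock", in_dim, out_dim, hidden_dim, kernel_size, stride, padding) layout consumed by plpq.py below.

from config import PLPQConfig   # hypothetical flat import; inside the package use `from .config import PLPQConfig`

config = PLPQConfig(
    image_size=(256, 256),
    patch_size=16,
    dropout=0.1,
    vocab_size=1000,              # presumably prod(levels), i.e. the codebook size
    levels=[8, 5, 5, 5],
    num_quantizers=1,
    num_in_channels=3,
    num_out_channels=3,
    use_wavelets=True,
    encoder_blocks=[["ResBlock", 128, 128, 256, 3, 1, 1]],
    decoder_blocks=[["ResBlock", 128, 128, 256, 3, 1, 1]],
)
config.save_pretrained("plpq-illustrative")  # writes a full config.json with these fields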
pfsq.py ADDED
@@ -0,0 +1,234 @@
+ """
+ Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
+ Code adapted from Jax version in Appendix A.1
+ """
+
+ from __future__ import annotations
+ from functools import wraps, partial
+ from contextlib import nullcontext
+ from typing import List, Tuple
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import Module
+ from torch import Tensor, int32
+ from torch.cuda.amp import autocast
+
+ from einops import rearrange, pack, unpack
+
+ # helper functions
+
+ def exists(v):
+     return v is not None
+
+ def default(*args):
+     for arg in args:
+         if exists(arg):
+             return arg
+     return None
+
+ def maybe(fn):
+     @wraps(fn)
+     def inner(x, *args, **kwargs):
+         if not exists(x):
+             return x
+         return fn(x, *args, **kwargs)
+     return inner
+
+ def pack_one(t, pattern):
+     return pack([t], pattern)
+
+ def unpack_one(t, ps, pattern):
+     return unpack(t, ps, pattern)[0]
+
+ # tensor helpers
+
+ def round_ste(z: Tensor) -> Tensor:
+     """Round with straight through gradients."""
+     zhat = z.round()
+     return z + (zhat - z).detach()
+
+ # main class
+
+ class PFSQ(Module):
+     def __init__(
+         self,
+         levels: List[int],
+         dim: int | None = None,
+         num_codebooks = 1,
+         keep_num_codebooks_dim: bool | None = None,
+         scale: float | None = None,
+         allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
+         channel_first: bool = False,
+         projection_has_bias: bool = True,
+         return_indices = True,
+         force_quantization_f32 = True
+     ):
+         super().__init__()
+         _levels = torch.tensor(levels, dtype=int32)
+         self.register_buffer("_levels", _levels, persistent = False)
+
+         _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
+         self.register_buffer("_basis", _basis, persistent = False)
+
+         self.scale = scale
+
+         codebook_dim = len(levels)
+         self.codebook_dim = codebook_dim
+
+         effective_codebook_dim = codebook_dim * num_codebooks
+         self.num_codebooks = num_codebooks
+         self.effective_codebook_dim = effective_codebook_dim
+
+         keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
+         assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
+         self.keep_num_codebooks_dim = keep_num_codebooks_dim
+
+         self.dim = default(dim, len(_levels) * num_codebooks)
+
+         self.channel_first = channel_first
+
+         has_projections = self.dim != effective_codebook_dim
+         self.project_in = nn.Linear(self.dim, effective_codebook_dim, bias = projection_has_bias) if has_projections else nn.Identity()
+         self.project_out = nn.Linear(effective_codebook_dim, self.dim, bias = projection_has_bias) if has_projections else nn.Identity()
+
+         self.has_projections = has_projections
+
+         self.return_indices = return_indices
+         if return_indices:
+             self.codebook_size = self._levels.prod().item()
+             implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size))
+             self.register_buffer("implicit_codebook", implicit_codebook, persistent = False)
+
+         self.allowed_dtypes = allowed_dtypes
+         self.force_quantization_f32 = force_quantization_f32
+
+     def bound(self, z, eps: float = 1e-3):
+         """ Bound `z`, an array of shape (..., d). """
+         half_l = (self._levels - 1) * (1 + eps) / 2
+         offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
+         shift = (offset / half_l).atanh()
+         return (z + shift).tanh() * half_l - offset
+
+     def quantize(self, z):
+         """ Quantizes z, returns quantized zhat, same shape as z. """
+         quantized = round_ste(self.bound(z))
+         half_width = self._levels // 2 # Renormalize to [-1, 1].
+         return quantized / half_width
+
+     def _scale_and_shift(self, zhat_normalized):
+         half_width = self._levels // 2
+         return (zhat_normalized * half_width) + half_width
+
+     def _scale_and_shift_inverse(self, zhat):
+         half_width = self._levels // 2
+         return (zhat - half_width) / half_width
+
+     def _indices_to_codes(self, indices):
+         level_indices = self.indices_to_level_indices(indices)
+         codes = self._scale_and_shift_inverse(level_indices)
+         return codes
+
+     def codes_to_indices(self, zhat):
+         """ Converts a `code` to an index in the codebook. """
+         assert zhat.shape[-1] == self.codebook_dim
+         zhat = self._scale_and_shift(zhat)
+         return (zhat * self._basis).sum(dim=-1).to(int32)
+
+     def indices_to_level_indices(self, indices):
+         """ Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings """
+         indices = rearrange(indices, '... -> ... 1')
+         codes_non_centered = (indices // self._basis) % self._levels
+         return codes_non_centered
+
+     def indices_to_codes(self, indices, return_first=False):
+         """ Inverse of `codes_to_indices`. """
+         assert exists(indices)
+
+         n_codes = indices.shape[-1]
+
+         is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
+
+         codes = self._indices_to_codes(indices)
+
+         if self.keep_num_codebooks_dim:
+             codes = rearrange(codes, '... c d -> ... (c d)')
+
+         if n_codes == 1:
+             return codes
+
+         codes = self.project_out(codes)
+
+         if is_img_or_video or self.channel_first:
+             codes = rearrange(codes, 'b ... d -> b d ...')
+
+         return codes
+
+     @autocast(enabled = False)
+     def forward(self, z):
+         """
+         einstein notation
+         b - batch
+         n - sequence (or flattened spatial dimensions)
+         d - feature dimension
+         c - number of codebook dim
+         """
+
+         is_img_or_video = z.ndim >= 4
+         need_move_channel_last = is_img_or_video or self.channel_first
+
+         # standardize image or video into (batch, seq, dimension)
+
+         if need_move_channel_last:
+             z = rearrange(z, 'b d ... -> b ... d')
+             z, ps = pack_one(z, 'b * d')
+
+         assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'
+
+         z = self.project_in(z)
+
+         z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks)
+
+         # whether to force quantization step to be full precision or not
+
+         force_f32 = self.force_quantization_f32
+         quantization_context = partial(autocast, enabled = False) if force_f32 else nullcontext
+
+         with quantization_context():
+             orig_dtype = z.dtype
+
+             if force_f32 and orig_dtype not in self.allowed_dtypes:
+                 z = z.float()
+
+             codes = self.quantize(z)
+
+             # returning indices could be optional
+
+             indices = None
+
+             if self.return_indices:
+                 indices = self.codes_to_indices(codes)
+
+             first_codes = codes[:, :, 0, :] # first codebook
+             codes = rearrange(codes, 'b n c d -> b n (c d)')
+
+             codes = codes.type(orig_dtype)
+             first_codes = first_codes.type(orig_dtype)
+
+         # project out
+         out = self.project_out(codes)
+
+         # reconstitute image or video dimensions
+
+         if need_move_channel_last:
+             out = unpack_one(out, ps, 'b * d')
+             out = rearrange(out, 'b ... d -> b d ...')
+
+             indices = maybe(unpack_one)(indices, ps, 'b * c')
+
+         if not self.keep_num_codebooks_dim and self.return_indices:
+             indices = maybe(rearrange)(indices, '... 1 -> ...')
+
+         # return quantized output and indices
+
+         return out, first_codes, indices
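
A minimal usage sketch for PFSQ on its own, assuming a hypothetical 64-dimensional feature stream; the shape comments follow the forward pass above.

import torch
from pfsq import PFSQ   # hypothetical flat import; inside the package it is `from .pfsq import PFSQ`

quantizer = PFSQ(levels=[8, 5, 5, 5], dim=64, num_codebooks=1)

z = torch.randn(2, 16, 64)                    # (batch, sequence, dim)
out, first_codes, indices = quantizer(z)
# out:         (2, 16, 64)  quantized features projected back to `dim`
# first_codes: (2, 16, 4)   codes of the first codebook, one value per level
# indices:     (2, 16)      integer ids in [0, prod(levels))

codes = quantizer.indices_to_codes(indices)   # (2, 16, 64), decoded back from the ids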
plpq.py ADDED
@@ -0,0 +1,196 @@
+
+ from transformers import PreTrainedModel
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+
+ from .wavelet import WaveletTransform
+ from .pfsq import PFSQ
+ from .config import PLPQConfig
+
+
+ class PLPQ(PreTrainedModel):
+     """
+     Pyramidal Local Patch Quantizer
+     """
+     config_class = PLPQConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+
+         if config.__dict__.get('use_wavelets', False):
+             wavelets = WaveletTransform(patch_size=config.patch_size)
+             wavelet_channels = wavelets.num_transformed_channels(config.num_in_channels)
+             in_proj = nn.Sequential(
+                 wavelets,
+                 nn.Conv2d(
+                     wavelet_channels, config.encoder_blocks[0][1],
+                     kernel_size=1, stride=1  # keep fully local
+                 )
+             )
+             out_proj = nn.Sequential(
+                 nn.Conv2d(
+                     config.decoder_blocks[-1][2], wavelet_channels,
+                     kernel_size=3, stride=1, padding=1
+                 ),
+                 WaveletTransform(patch_size=config.patch_size, inverse=True)
+             )
+         else:
+             in_proj = nn.Conv2d(
+                 config.num_in_channels, config.encoder_blocks[0][1],
+                 kernel_size=config.patch_size, stride=config.patch_size
+             )
+             out_proj = nn.Conv2d(
+                 config.decoder_blocks[-1][2], config.num_out_channels,
+                 kernel_size=3, stride=1, padding=1
+             )
+
+         self.encoder = nn.Sequential(
+             in_proj,
+             nn.SiLU(),
+             *[
+                 PatchResidualConvBlock(*block_params[1:]) if block_params[0] == "ResBlock" else Downsample(*block_params[1:])
+                 for block_params in config.encoder_blocks
+             ]
+         )
+
+         # Pyramidal Quantizer
+         self.quantizer = PFSQ(
+             levels = config.levels,                 # number of levels for each codebook
+             num_codebooks = config.num_quantizers,  # number of quantizers
+             dim = config.encoder_blocks[-1][2],     # this is the input feature dimension, defaults to log2(codebook_size) if not defined
+         )
+
+         # coarse decoder output -> 32x32 supervision
+         self.coarse_decoder = nn.Conv2d(len(config.levels), config.num_out_channels, kernel_size=1, stride=1)
+
+         self.decoder = nn.Sequential(
+             *[
+                 PatchResidualConvBlock(*block_params[1:]) if block_params[0] == "ResBlock" else Upsample(*block_params[1:])
+                 for block_params in config.decoder_blocks
+             ],
+             out_proj
+         )
+
+
+     def get_num_params(self) -> int:
+         """
+         Return the number of parameters in the model.
+         """
+         return sum(p.numel() for p in self.parameters())
+
+
+     @torch.no_grad()
+     def quantize(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Quantize the input tensor.
+         Parameters:
+             x (torch.Tensor): The input tensor. Size (b, c, h, w)
+         Returns:
+             torch.Tensor: The codebook indices. Size (b, h*w), with a trailing codebook dimension when num_quantizers > 1
+         """
+         # encode the input
+         z = self.encoder(x).permute(0, 2, 3, 1).contiguous()
+         # reshape the input
+         b, h, w, c = z.shape
+         z = z.view(b, h * w, -1)
+
+         # quantize the input
+         quantized, coarse_quantized, all_codes = self.quantizer(z)
+
+         return all_codes
+
+
+     @torch.no_grad()
+     def decode(self, indices: torch.Tensor) -> torch.Tensor:
+         """
+         Parameters:
+             indices: torch.Tensor of code indices, shape (b, t)
+         Returns:
+             torch.Tensor: the reconstructed image, shape (b, c, h, w); the coarse decoder is used when a single code is passed
+         """
+
+         ncodes = indices.shape[-1]
+         emb = self.quantizer.indices_to_codes(indices).squeeze(-1)
+
+         # reshape [b t c] -> [b c h w]
+         b, h, w = emb.size(0), int(math.sqrt(emb.size(1))), int(math.sqrt(emb.size(1)))
+         emb = emb.permute(0, 2, 1).view(b, -1, h, w).contiguous()
+
+         if ncodes == 1:
+             pred = self.coarse_decoder(emb)
+             return pred
+
+         # full decoder: full image prediction
+         pred = self.decoder(emb)
+
+         return pred
+
+
+
+ class LayerNorm(nn.Module):
+     """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
+
+     def __init__(self, ndim, bias):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(ndim))
+         self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
+
+     def forward(self, input):
+         return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
+
+
+
+ class PatchResidualConvBlock(nn.Module):
+
+     def __init__(self, in_dim, out_dim, hidden_dim, kernel_size, stride, padding, dropout=0.1) -> None:
+         super().__init__()
+         self.nonlinearity = nn.SiLU()
+         self.ln1 = LayerNorm(in_dim, bias=True)
+         self.dropout = nn.Dropout(dropout)
+         self.conv1 = nn.Conv2d(in_dim, hidden_dim, kernel_size=kernel_size, stride=stride, padding=padding)
+         self.conv2 = nn.Conv2d(hidden_dim, out_dim, kernel_size=kernel_size, stride=stride, padding=padding)
+
+     def forward(self, x):
+         b, c, h, w = x.shape
+         z = self.ln1(x.permute(0, 2, 3, 1).reshape(b * h * w, c)).reshape(b, h, w, c).permute(0, 3, 1, 2).contiguous()
+         z = self.nonlinearity(self.conv1(z))
+         z = self.dropout(z)
+         z = self.nonlinearity(self.conv2(z))
+         return z + x
+
+
+
+ class Upsample(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         self.conv = torch.nn.Conv2d(in_channels,
+                                     out_channels,
+                                     kernel_size=3,
+                                     stride=1,
+                                     padding=1)
+
+     def forward(self, x):
+         x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+         x = self.conv(x)
+         return x
+
+
+
+ class Downsample(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         # no asymmetric padding in torch conv, must do it ourselves
+         self.conv = torch.nn.Conv2d(in_channels,
+                                     out_channels,
+                                     kernel_size=3,
+                                     stride=2,
+                                     padding=0)
+
+     def forward(self, x):
+         pad = (0, 1, 0, 1)
+         x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+         x = self.conv(x)
+         return x
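
An end-to-end sketch under assumed settings, not the released configuration: a single ResBlock in encoder and decoder, wavelet patching enabled so the inverse transform restores full resolution, and levels=[8, 5, 5, 5]. The package name is illustrative.

import torch
from plpq_repo import PLPQ, PLPQConfig   # assuming the uploaded folder is importable as a package via its __init__.py

config = PLPQConfig(
    image_size=(256, 256), patch_size=16, dropout=0.1, vocab_size=1000,
    levels=[8, 5, 5, 5], num_quantizers=1,
    num_in_channels=3, num_out_channels=3, use_wavelets=True,
    encoder_blocks=[["ResBlock", 128, 128, 256, 3, 1, 1]],
    decoder_blocks=[["ResBlock", 128, 128, 256, 3, 1, 1]],
)
model = PLPQ(config).eval()

x = torch.randn(1, 3, 256, 256)   # (b, c, h, w)
indices = model.quantize(x)       # (1, 256): one code id per 16x16 patch
recon = model.decode(indices)     # (1, 3, 256, 256): decoder output passed through the inverse wavelet out_proj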
wavelet.py ADDED
@@ -0,0 +1,167 @@
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ import math
+
+
+ class WaveletTransform(nn.Module):
+
+     def __init__(self, patch_size: int, inverse: bool = False):
+         '''
+         `patchwise` in forward/invert makes *no difference*; the result
+         is numerically identical either way. It's still enabled by default
+         in case we pass in a non-square image, which may not be equivalent.
+         `reshape` is pretty much useless.
+         TODO: Clean up these options.
+         '''
+         super().__init__()
+         self.patch_size = patch_size
+         self.inverse = inverse
+         # From https://github.com/NVIDIA/Cosmos-Tokenizer/blob/3584ae752ce8ebdbe06a420bf60d7513c0e878cc/cosmos_tokenizer/modules/patching.py#L33
+         self.haar = torch.tensor([0.7071067811865476, 0.7071067811865476])
+         self.arange = torch.arange(len(self.haar))
+         self.steps = int(math.log2(self.patch_size))
+
+     def num_transformed_channels(self, in_channels: int = 3) -> int:
+         '''
+         Returns the number of channels to expect in the transformed image
+         given the channels in the input image.
+         '''
+         return in_channels * (4 ** self.steps)
+
+
+     def forward(self, x: torch.Tensor, patchwise: bool = True, reshape: bool = False) -> torch.Tensor:
+         if self.inverse:
+             return self.invert(x, patchwise=patchwise, from_reshaped=reshape)
+         else:
+             return self.transform(x, patchwise=patchwise, reshape=reshape)
+
+
+     def transform(self, x: torch.Tensor, patchwise: bool = True, reshape: bool = False) -> torch.Tensor:
+         '''
+         ### Parameters:
+         `x`: ImageNet-normalized images with shape (B C H W)
+         `patchwise`: Whether to compute independently on patches
+         `reshape`: Reshape the results to match the input HxW
+         ### Returns:
+         If `reshape`, returns (B C H W)
+         otherwise, returns (B C*patch_size**2 H/patch_size W/patch_size)
+         '''
+         p = self.patch_size
+         if patchwise:
+             # Place patches into batch dimension
+             # (B C H W) -> (B*L C H/root(L), W/root(L))
+             b, c, h, w = x.shape
+             init_b = b
+             # (B C H W) -> (B C LH LW P P)
+             x = x.reshape(b, c, h//p, p, w//p, p).moveaxis(4,3)
+             # (B C LH LW P P) -> (B' C P P)
+             x = x.moveaxis(1,3).reshape(-1, c, p, p)
+
+         for _ in range(self.steps):
+             x = self.dwt(x)
+
+         if patchwise:
+             # Extract patches from batch dimension
+             # (B' C' 1 1) -> (B LH LW C') -> (B C' LH LW)
+             x = x.reshape(init_b, h//p, w//p, -1).moveaxis(3,1)
+         if reshape:
+             # (B C*patch_size**2 H/patch_size W/patch_size) -> (B C H W)
+             b, cp2, hdp, wdp = x.shape
+             c, h, w = cp2//(p**2), hdp*p, wdp*p
+             x = x.reshape(b, p, p, c, hdp, wdp)
+             x = x.moveaxis(3,1).moveaxis(3,4).reshape(b, c, h, w).contiguous()
+         return x
+
+     def invert(self, x: torch.Tensor, patchwise: bool = True, from_reshaped: bool = False) -> torch.Tensor:
+         '''
+         ### Parameters:
+         `x`: Wavelet-space input of either (B C H W) (when `from_reshaped=True`) or
+              (B C*patch_size**2 H/patch_size W/patch_size)
+         `patchwise`: Whether to compute independently on patches
+         `from_reshaped`: Determines the shape of `x`; should match the value of `reshape`
+              used when calling `forward`
+         '''
+         p = self.patch_size
+         if from_reshaped:
+             # (B C H W) -> (B C*patch_size**2 H/patch_size W/patch_size)
+             b, c, h, w = x.shape
+             cp2, hdp, wdp = c*self.patch_size**2, h//self.patch_size, w//self.patch_size
+             x = x.reshape(b, c, self.patch_size, hdp, self.patch_size, wdp)
+             x = x.moveaxis(4,3).moveaxis(1,3).reshape(b, cp2, hdp, wdp)
+         if patchwise:
+             # Put patches into batch dimension
+             # (B C' LH LW) -> (B LH LW C') -> (B' C' 1 1)
+             init_b, lh, lw = x.shape[0], x.shape[2], x.shape[3]
+             x = x.moveaxis(1,3).reshape(-1, x.shape[1], 1, 1)
+
+         for _ in range(self.steps):
+             x = self.idwt(x)
+
+         if patchwise:
+             # Extract patches from batch dimension and expand
+             # (B' C P P) -> (B C LH LW P P)
+             x = x.reshape(init_b, lh, lw, *x.shape[1:]).moveaxis(3,1)
+             # (B C LH LW P P) -> (B C H W)
+             x = x.moveaxis(3,4).reshape(*x.shape[:2], lh*p, lw*p)
+         return x
+
+
+     def dwt(self, x: torch.Tensor):
+         dtype = x.dtype
+         h = self.haar
+
+         n = h.shape[0]
+         g = x.shape[1]
+         hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
+         hh = (h * ((-1) ** self.arange)).reshape(1, 1, -1).repeat(g, 1, 1)
+         hh = hh.to(device=x.device, dtype=dtype)
+         hl = hl.to(device=x.device, dtype=dtype)
+
+         x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode='reflect').to(dtype)
+         xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2))
+         xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2))
+         xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1))
+         xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1))
+         xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1))
+         xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1))
+
+         return 0.5 * torch.cat([xll, xlh, xhl, xhh], dim=1)
+
+
+     def idwt(self, x: torch.Tensor):
+         dtype = x.dtype
+         h = self.haar
+         n = h.shape[0]
+
+         g = x.shape[1] // 4
+         hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
+         hh = (h * ((-1) ** self.arange)).reshape(1, 1, -1).repeat(g, 1, 1)
+         hh = hh.to(device=x.device, dtype=dtype)
+         hl = hl.to(device=x.device, dtype=dtype)
+
+         xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1)
+
+         # Inverse transform.
+         yl = torch.nn.functional.conv_transpose2d(
+             xll, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
+         )
+         yl += torch.nn.functional.conv_transpose2d(
+             xlh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
+         )
+         yh = torch.nn.functional.conv_transpose2d(
+             xhl, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
+         )
+         yh += torch.nn.functional.conv_transpose2d(
+             xhh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
+         )
+         y = torch.nn.functional.conv_transpose2d(
+             yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)
+         )
+         y += torch.nn.functional.conv_transpose2d(
+             yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)
+         )
+
+         return 2.0 * y
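
A quick round-trip sketch, assuming a 4x4 patch size; the inverse transform should reproduce the input up to floating point error.

import torch
from wavelet import WaveletTransform   # hypothetical flat import; inside the package it is `from .wavelet import WaveletTransform`

wt = WaveletTransform(patch_size=4)                # 2 Haar levels (log2(4))
iwt = WaveletTransform(patch_size=4, inverse=True)

x = torch.randn(2, 3, 32, 32)
y = wt(x)            # (2, 3 * 4**2, 8, 8) = (2, 48, 8, 8): each 4x4 patch becomes one 48-channel position
x_rec = iwt(y)       # (2, 3, 32, 32)

print(torch.allclose(x, x_rec, atol=1e-5))         # expected: True, reconstruction is exact up to float32 error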