coralLight committed on
Commit 1c3f916 · 1 Parent(s): 09c8989

xl version

Files changed (7)
  1. NoiseTransformer.py +26 -0
  2. README.md +17 -3
  3. SVDNoiseUnet.py +430 -0
  4. app.py +958 -56
  5. free_lunch_utils.py +304 -0
  6. requirements.txt +10 -6
  7. sdxl.pth +3 -0
NoiseTransformer.py ADDED
@@ -0,0 +1,26 @@
1
+ import torch.nn as nn
2
+
3
+ from torch.nn import functional as F
4
+ from timm import create_model
5
+
6
+
7
+ __all__ = ['NoiseTransformer']
8
+
9
+ class NoiseTransformer(nn.Module):
10
+ def __init__(self, resolution=(128,96)):
11
+ super().__init__()
12
+ self.upsample = lambda x: F.interpolate(x, [224,224])
13
+ self.downsample = lambda x: F.interpolate(x, [resolution[0],resolution[1]])
14
+ self.upconv = nn.Conv2d(7,4,(1,1),(1,1),(0,0))
15
+ self.downconv = nn.Conv2d(4,3,(1,1),(1,1),(0,0))
16
+ # self.upconv = nn.Conv2d(7,4,(1,1),(1,1),(0,0))
17
+ self.swin = create_model("swin_tiny_patch4_window7_224",pretrained=True)
18
+
19
+
20
+ def forward(self, x, residual=False):
21
+ if residual:
22
+ x = self.upconv(self.downsample(self.swin.forward_features(self.downconv(self.upsample(x))))) + x
23
+ else:
24
+ x = self.upconv(self.downsample(self.swin.forward_features(self.downconv(self.upsample(x)))))
25
+
26
+ return x
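For orientation, here is a minimal smoke test of the module above; it is a sketch that assumes a recent `timm` whose Swin `forward_features` returns NHWC feature maps and that the pretrained `swin_tiny_patch4_window7_224` weights can be fetched.

```python
import torch
from NoiseTransformer import NoiseTransformer

# Latent resolution is image size // 8; (128, 96) corresponds to 1024x768 pixels.
model = NoiseTransformer(resolution=(128, 96)).eval()
latent = torch.randn(1, 4, 128, 96)
with torch.no_grad():
    out = model(latent, residual=True)
print(out.shape)  # should round-trip to torch.Size([1, 4, 128, 96])
```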
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Hyperparameters Are All You Need Xl Version
3
  emoji: 🖼
4
  colorFrom: purple
5
  colorTo: red
@@ -8,7 +8,21 @@ sdk_version: 5.44.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
- short_description: training-free algorithm for few step inference of sdxl
 
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: Hyperparameters-are-all-you-need-xl-version-improved-implementation
3
  emoji: 🖼
4
  colorFrom: purple
5
  colorTo: red
 
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
+ short_description: training-free few-step diffusion ODE solver
12
+ header: default
13
  ---
14
 
15
+ **Abstract:** The diffusion model is a state-of-the-art generative model that generates an image by applying a neural network iteratively. Moreover, this generation process can be regarded as an algorithm solving an ordinary differential equation or a stochastic differential equation. Based on an analysis of the truncation error of the diffusion ODE and SDE, our study proposes a training-free algorithm that generates high-quality 512 x 512 and 1024 x 1024 images in eight steps, with flexible guidance scales. To the best of our knowledge, our algorithm is the first to sample a 1024 x 1024 image in 8 steps with FID performance comparable to that of the latest distillation models, yet without additional training. Meanwhile, our algorithm can also generate a 512 x 512 image in 8 steps, with better FID performance than the state-of-the-art ODE solver DPM++ 2M using 20 steps. We validate our eight-step image generation algorithm on the COCO 2014, COCO 2017, and LAION datasets, where our best FID scores are 15.7, 22.35, and 17.52, versus 17.3, 23.75, and 17.33 for DPM++ 2M and 19.07, 25.50, and 18.06 for the state-of-the-art AMED-Plugin solver. We also apply the algorithm to five-step inference without additional training, where the best FID scores on the datasets above are 19.18, 23.24, and 19.61, respectively, comparable to the eight-step AMED-Plugin solver, four-step SDXL-Turbo, and the five-step state-of-the-art distillation model Flash Diffusion. We further validate our algorithm on synthesizing 1024 x 1024 images within 6 steps, where its FID performance comes within a limited distance of the latest distillation algorithms.
16
+
17
+ This demo is a simplified implementation of the approach described in the paper ["Hyperparameters are all you need: Using five-step inference for an original diffusion model to generate images comparable to the latest distillation model"](https://arxiv.org/abs/2510.02390); a minimal sketch of the underlying solver step follows the citation below.
18
+
19
+ ```
20
+ @misc{hyper,
21
+ title={Hyperparameters are all you need: Using five-step inference for an original diffusion model to generate images comparable to the latest distillation model},
22
+ author={Zilai Li},
23
+ year={2025},
24
+ eprint={2510.02390},
25
+ archivePrefix={arXiv},
26
+ primaryClass={eess.IV}
27
+ }
28
+ ```
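For readers skimming the diff, here is a minimal, self-contained sketch of the Karras sigma schedule and the DPM-Solver++(2M) update that the `sample_dpmpp_2m` sampler added to `app.py` below is built on. It is not the paper's full algorithm; the `denoise` callable, step count, and sigma range are placeholders.

```python
import torch

def karras_sigmas(n, sigma_min=0.03, sigma_max=14.6, rho=7.0):
    """Karras et al. (2022) noise schedule, largest sigma first, with a trailing zero."""
    ramp = torch.linspace(0, 1, n)
    sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    return torch.cat([sigmas, sigmas.new_zeros(1)])

@torch.no_grad()
def dpmpp_2m(denoise, x, sigmas):
    """Second-order multistep DPM-Solver++; `denoise(x, sigma)` must return the predicted x0."""
    t_fn = lambda sigma: sigma.log().neg()
    sigma_fn = lambda t: t.neg().exp()
    old_denoised = None
    for i in range(len(sigmas) - 1):
        denoised = denoise(x, sigmas[i])
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        if old_denoised is None or sigmas[i + 1] == 0:
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
        else:
            r = (t - t_fn(sigmas[i - 1])) / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
        old_denoised = denoised
    return x

# Eight-step usage, matching the step budget in the abstract:
# sigmas = karras_sigmas(8)
# sample = dpmpp_2m(my_denoiser, torch.randn(1, 4, 128, 128) * sigmas[0], sigmas)
```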
SVDNoiseUnet.py ADDED
@@ -0,0 +1,430 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import einops
4
+
5
+ from torch.nn import functional as F
6
+ from torch.jit import Final
7
+ from timm.layers import use_fused_attn
8
+ from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, get_act_layer
9
+ from abc import abstractmethod
10
+ from NoiseTransformer import NoiseTransformer
11
+ from einops import rearrange
12
+ __all__ = ['SVDNoiseUnet', 'SVDNoiseUnet_Concise']
13
+
14
+ class Attention(nn.Module):
15
+ fused_attn: Final[bool]
16
+
17
+ def __init__(
18
+ self,
19
+ dim: int,
20
+ num_heads: int = 8,
21
+ qkv_bias: bool = False,
22
+ qk_norm: bool = False,
23
+ attn_drop: float = 0.,
24
+ proj_drop: float = 0.,
25
+ norm_layer: nn.Module = nn.LayerNorm,
26
+ ) -> None:
27
+ super().__init__()
28
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
29
+ self.num_heads = num_heads
30
+ self.head_dim = dim // num_heads
31
+ self.scale = self.head_dim ** -0.5
32
+ self.fused_attn = use_fused_attn()
33
+
34
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
35
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
36
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
37
+ self.attn_drop = nn.Dropout(attn_drop)
38
+ self.proj = nn.Linear(dim, dim)
39
+ self.proj_drop = nn.Dropout(proj_drop)
40
+
41
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
42
+ B, N, C = x.shape
43
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
44
+ q, k, v = qkv.unbind(0)
45
+ q, k = self.q_norm(q), self.k_norm(k)
46
+
47
+ if self.fused_attn:
48
+ x = F.scaled_dot_product_attention(
49
+ q, k, v,
50
+ dropout_p=self.attn_drop.p if self.training else 0.,
51
+ )
52
+ else:
53
+ q = q * self.scale
54
+ attn = q @ k.transpose(-2, -1)
55
+ attn = attn.softmax(dim=-1)
56
+ attn = self.attn_drop(attn)
57
+ x = attn @ v
58
+
59
+ x = x.transpose(1, 2).reshape(B, N, C)
60
+ x = self.proj(x)
61
+ x = self.proj_drop(x)
62
+ return x
63
+
64
+
65
+ class SVDNoiseUnet(nn.Module):
66
+ def __init__(self, in_channels=4, out_channels=4, resolution=(128,96)): # resolution = size // 8
67
+ super(SVDNoiseUnet, self).__init__()
68
+
69
+ _in_1 = int(resolution[0] * in_channels // 2)
70
+ _out_1 = int(resolution[0] * out_channels // 2)
71
+
72
+ _in_2 = int(resolution[1] * in_channels // 2)
73
+ _out_2 = int(resolution[1] * out_channels // 2)
74
+ self.mlp1 = nn.Sequential(
75
+ nn.Linear(_in_1, 64),
76
+ nn.ReLU(inplace=True),
77
+ nn.Linear(64, _out_1),
78
+ )
79
+ self.mlp2 = nn.Sequential(
80
+ nn.Linear(_in_2, 64),
81
+ nn.ReLU(inplace=True),
82
+ nn.Linear(64, _out_2),
83
+ )
84
+
85
+ self.mlp3 = nn.Sequential(
86
+ nn.Linear(_in_2, _out_2),
87
+ )
88
+
89
+ self.attention = Attention(_out_2)
90
+
91
+ self.bn = nn.BatchNorm1d(256)
92
+ self.bn2 = nn.BatchNorm1d(192)
93
+
94
+ self.mlp4 = nn.Sequential(
95
+ nn.Linear(_out_2, 1024),
96
+ nn.ReLU(inplace=True),
97
+ nn.Linear(1024, _out_2),
98
+ )
99
+ self.ffn = nn.Sequential(
100
+ nn.Linear(256, 384), # Expand
101
+ nn.ReLU(inplace=True),
102
+ nn.Linear(384, 192) # Reduce to target size
103
+ )
104
+ self.ffn2 = nn.Sequential(
105
+ nn.Linear(256, 384), # Expand
106
+ nn.ReLU(inplace=True),
107
+ nn.Linear(384, 192) # Reduce to target size
108
+ )
109
+ # self.adaptive_pool = nn.AdaptiveAvgPool2d((256, 192))
110
+
111
+ def forward(self, x, residual=False):
112
+ b, c, h, w = x.shape
113
+ x = einops.rearrange(x, "b (a c)h w ->b (a h)(c w)", a=2,c=2) # x -> [1, 256, 256]
114
+ U, s, V = torch.linalg.svd(x) # U->[b 256 256], s-> [b 256], V->[b 256 256]
115
+ U_T = U.permute(0, 2, 1)
116
+ U_out = self.ffn(self.mlp1(U_T))
117
+ U_out = self.bn(U_out)
118
+ U_out = U_out.transpose(1, 2)
119
+ U_out = self.ffn2(U_out) # [b, 256, 256] -> [b, 256, 192]
120
+ U_out = self.bn2(U_out)
121
+ U_out = U_out.transpose(1, 2)
122
+ # U_out = self.bn(U_out)
123
+ V_out = self.mlp2(V)
124
+ s_out = self.mlp3(s).unsqueeze(1) # s -> [b, 1, 256] => [b, 256, 256]
125
+ out = U_out + V_out + s_out
126
+ # print(out.size())
127
+ out = out.squeeze(1)
128
+ out = self.attention(out).mean(1)
129
+ out = self.mlp4(out) + s
130
+ diagonal_out = torch.diag_embed(out)
131
+ padded_diag = F.pad(diagonal_out, (0, 0, 0, 64), mode='constant', value=0) # Shape: [b, 1, 256, 192]
132
+ pred = U @ padded_diag @ V
133
+ return einops.rearrange(pred, "b (a h)(c w) -> b (a c) h w", a=2,c=2)
134
+
135
+ class SVDNoiseUnet64(nn.Module):
136
+ def __init__(self, in_channels=4, out_channels=4, resolution=64): # resolution = size // 8
137
+ super(SVDNoiseUnet64, self).__init__()
138
+
139
+ _in = int(resolution * in_channels // 2)
140
+ _out = int(resolution * out_channels // 2)
141
+ self.mlp1 = nn.Sequential(
142
+ nn.Linear(_in, 64),
143
+ nn.ReLU(inplace=True),
144
+ nn.Linear(64, _out),
145
+ )
146
+ self.mlp2 = nn.Sequential(
147
+ nn.Linear(_in, 64),
148
+ nn.ReLU(inplace=True),
149
+ nn.Linear(64, _out),
150
+ )
151
+
152
+ self.mlp3 = nn.Sequential(
153
+ nn.Linear(_in, _out),
154
+ )
155
+
156
+ self.attention = Attention(_out)
157
+
158
+ self.bn = nn.BatchNorm2d(_out)
159
+
160
+ self.mlp4 = nn.Sequential(
161
+ nn.Linear(_out, 1024),
162
+ nn.ReLU(inplace=True),
163
+ nn.Linear(1024, _out),
164
+ )
165
+
166
+ def forward(self, x, residual=False):
167
+ b, c, h, w = x.shape
168
+ x = einops.rearrange(x, "b (a c)h w ->b (a h)(c w)", a=2,c=2) # x -> [1, 256, 256]
169
+ U, s, V = torch.linalg.svd(x) # U->[b 256 256], s-> [b 256], V->[b 256 256]
170
+ U_T = U.permute(0, 2, 1)
171
+ out = self.mlp1(U_T) + self.mlp2(V) + self.mlp3(s).unsqueeze(1) # s -> [b, 1, 256] => [b, 256, 256]
172
+ out = self.attention(out).mean(1)
173
+ out = self.mlp4(out) + s
174
+ pred = U @ torch.diag_embed(out) @ V
175
+ return einops.rearrange(pred, "b (a h)(c w) -> b (a c) h w", a=2,c=2)
176
+
177
+
178
+
179
+ class SVDNoiseUnet128(nn.Module):
180
+ def __init__(self, in_channels=4, out_channels=4, resolution=128): # resolution = size // 8
181
+ super(SVDNoiseUnet128, self).__init__()
182
+
183
+ _in = int(resolution * in_channels // 2)
184
+ _out = int(resolution * out_channels // 2)
185
+ self.mlp1 = nn.Sequential(
186
+ nn.Linear(_in, 64),
187
+ nn.ReLU(inplace=True),
188
+ nn.Linear(64, _out),
189
+ )
190
+ self.mlp2 = nn.Sequential(
191
+ nn.Linear(_in, 64),
192
+ nn.ReLU(inplace=True),
193
+ nn.Linear(64, _out),
194
+ )
195
+
196
+ self.mlp3 = nn.Sequential(
197
+ nn.Linear(_in, _out),
198
+ )
199
+
200
+ self.attention = Attention(_out)
201
+
202
+ self.bn = nn.BatchNorm2d(_out)
203
+
204
+ self.mlp4 = nn.Sequential(
205
+ nn.Linear(_out, 1024),
206
+ nn.ReLU(inplace=True),
207
+ nn.Linear(1024, _out),
208
+ )
209
+
210
+ def forward(self, x, residual=False):
211
+ b, c, h, w = x.shape
212
+ x = einops.rearrange(x, "b (a c)h w ->b (a h)(c w)", a=2,c=2) # x -> [1, 256, 256]
213
+ U, s, V = torch.linalg.svd(x) # U->[b 256 256], s-> [b 256], V->[b 256 256]
214
+ U_T = U.permute(0, 2, 1)
215
+ out = self.mlp1(U_T) + self.mlp2(V) + self.mlp3(s).unsqueeze(1) # s -> [b, 1, 256] => [b, 256, 256]
216
+ out = self.attention(out).mean(1)
217
+ out = self.mlp4(out) + s
218
+ pred = U @ torch.diag_embed(out) @ V
219
+ return einops.rearrange(pred, "b (a h)(c w) -> b (a c) h w", a=2,c=2)
220
+
221
+
222
+
223
+ class SVDNoiseUnet_Concise(nn.Module):
224
+ def __init__(self, in_channels=4, out_channels=4, resolution=64):
225
+ super(SVDNoiseUnet_Concise, self).__init__()
226
+
227
+
228
+ from diffusers.models.normalization import AdaGroupNorm
229
+
230
+ class NPNet(nn.Module):
231
+ def __init__(self, model_id, pretrained_path=' ', device='cuda') -> None:
232
+ super(NPNet, self).__init__()
233
+
234
+ assert model_id in ['SD1.5', 'DreamShaper', 'DiT']
235
+
236
+ self.model_id = model_id
237
+ self.device = device
238
+ self.pretrained_path = pretrained_path
239
+
240
+ (
241
+ self.unet_svd,
242
+ self.unet_embedding,
243
+ self.text_embedding,
244
+ self._alpha,
245
+ self._beta
246
+ ) = self.get_model()
247
+ def save_model(self, save_path: str):
248
+ """
249
+ Save this NPNet so that get_model() can later reload it.
250
+ """
251
+ torch.save({
252
+ "unet_svd": self.unet_svd.state_dict(),
253
+ "unet_embedding": self.unet_embedding.state_dict(),
254
+ "embeeding": self.text_embedding.state_dict(), # matches get_model’s key
255
+ "alpha": self._alpha,
256
+ "beta": self._beta,
257
+ }, save_path)
258
+ print(f"NPNet saved to {save_path}")
259
+ def get_model(self):
260
+
261
+ unet_embedding = NoiseTransformer(resolution=(128,96)).to(self.device).to(torch.float32)
262
+ unet_svd = SVDNoiseUnet(resolution=(128,96)).to(self.device).to(torch.float32)
263
+
264
+ if self.model_id == 'DiT':
265
+ text_embedding = AdaGroupNorm(768 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
266
+ else:
267
+ text_embedding = AdaGroupNorm(768 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
268
+
269
+ # initialize random _alpha and _beta when no checkpoint is provided
270
+ _alpha = torch.randn(1, device=self.device)
271
+ _beta = torch.randn(1, device=self.device)
272
+
273
+ if '.pth' in self.pretrained_path:
274
+ gloden_unet = torch.load(self.pretrained_path)
275
+ unet_svd.load_state_dict(gloden_unet["unet_svd"],strict=True)
276
+ unet_embedding.load_state_dict(gloden_unet["unet_embedding"],strict=True)
277
+ text_embedding.load_state_dict(gloden_unet["embeeding"],strict=True)
278
+ _alpha = gloden_unet["alpha"]
279
+ _beta = gloden_unet["beta"]
280
+
281
+ print("Load Successfully!")
282
+
283
+ return unet_svd, unet_embedding, text_embedding, _alpha, _beta
284
+
285
+ else:
286
+ return unet_svd, unet_embedding, text_embedding, _alpha, _beta
287
+
288
+
289
+ def forward(self, initial_noise, prompt_embeds):
290
+
291
+ prompt_embeds = prompt_embeds.float().view(prompt_embeds.shape[0], -1)
292
+ text_emb = self.text_embedding(initial_noise.float(), prompt_embeds)
293
+
294
+ encoder_hidden_states_svd = initial_noise
295
+ encoder_hidden_states_embedding = initial_noise + text_emb
296
+
297
+ golden_embedding = self.unet_embedding(encoder_hidden_states_embedding.float())
298
+
299
+ golden_noise = self.unet_svd(encoder_hidden_states_svd.float()) + (
300
+ 2 * torch.sigmoid(self._alpha) - 1) * text_emb + self._beta * golden_embedding
301
+
302
+ return golden_noise
303
+
304
+
305
+ class NPNet64(nn.Module):
306
+ def __init__(self, model_id, pretrained_path=' ', device='cuda') -> None:
307
+ super(NPNet64, self).__init__()
308
+ self.model_id = model_id
309
+ self.device = device
310
+ self.pretrained_path = pretrained_path
311
+
312
+ (
313
+ self.unet_svd,
314
+ self.unet_embedding,
315
+ self.text_embedding,
316
+ self._alpha,
317
+ self._beta
318
+ ) = self.get_model()
319
+
320
+ def save_model(self, save_path: str):
321
+ """
322
+ Save this NPNet so that get_model() can later reload it.
323
+ """
324
+ torch.save({
325
+ "unet_svd": self.unet_svd.state_dict(),
326
+ "unet_embedding": self.unet_embedding.state_dict(),
327
+ "embeeding": self.text_embedding.state_dict(), # matches get_model’s key
328
+ "alpha": self._alpha,
329
+ "beta": self._beta,
330
+ }, save_path)
331
+ print(f"NPNet saved to {save_path}")
332
+
333
+ def get_model(self):
334
+
335
+ unet_embedding = NoiseTransformer(resolution=(64,64)).to(self.device).to(torch.float32)
336
+ unet_svd = SVDNoiseUnet64(resolution=64).to(self.device).to(torch.float32)
337
+ _alpha = torch.randn(1, device=self.device)
338
+ _beta = torch.randn(1, device=self.device)
339
+
340
+ text_embedding = AdaGroupNorm(768 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
341
+
342
+
343
+ if '.pth' in self.pretrained_path:
344
+ gloden_unet = torch.load(self.pretrained_path)
345
+ unet_svd.load_state_dict(gloden_unet["unet_svd"])
346
+ unet_embedding.load_state_dict(gloden_unet["unet_embedding"])
347
+ text_embedding.load_state_dict(gloden_unet["embeeding"])
348
+ _alpha = gloden_unet["alpha"]
349
+ _beta = gloden_unet["beta"]
350
+
351
+ print("Load Successfully!")
352
+
353
+ return unet_svd, unet_embedding, text_embedding, _alpha, _beta
354
+
355
+
356
+ def forward(self, initial_noise, prompt_embeds):
357
+
358
+ prompt_embeds = prompt_embeds.float().view(prompt_embeds.shape[0], -1)
359
+ text_emb = self.text_embedding(initial_noise.float(), prompt_embeds)
360
+
361
+ encoder_hidden_states_svd = initial_noise
362
+ encoder_hidden_states_embedding = initial_noise + text_emb
363
+
364
+ golden_embedding = self.unet_embedding(encoder_hidden_states_embedding.float())
365
+
366
+ golden_noise = self.unet_svd(encoder_hidden_states_svd.float()) + (
367
+ 2 * torch.sigmoid(self._alpha) - 1) * text_emb + self._beta * golden_embedding
368
+
369
+ return golden_noise
370
+
371
+ class NPNet128(nn.Module):
372
+ def __init__(self, model_id, pretrained_path=' ', device='cuda') -> None:
373
+ super(NPNet128, self).__init__()
374
+
375
+ assert model_id in ['SDXL', 'DreamShaper', 'DiT']
376
+
377
+ self.model_id = model_id
378
+ self.device = device
379
+ self.pretrained_path = pretrained_path
380
+
381
+ (
382
+ self.unet_svd,
383
+ self.unet_embedding,
384
+ self.text_embedding,
385
+ self._alpha,
386
+ self._beta
387
+ ) = self.get_model()
388
+
389
+ def get_model(self):
390
+
391
+ unet_embedding = NoiseTransformer(resolution=(128,128)).to(self.device).to(torch.float32)
392
+ unet_svd = SVDNoiseUnet128(resolution=128).to(self.device).to(torch.float32)
393
+
394
+ if self.model_id == 'DiT':
395
+ text_embedding = AdaGroupNorm(1024 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
396
+ else:
397
+ text_embedding = AdaGroupNorm(2048 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
398
+
399
+
400
+ if '.pth' in self.pretrained_path:
401
+ gloden_unet = torch.load(self.pretrained_path)
402
+ unet_svd.load_state_dict(gloden_unet["unet_svd"])
403
+ unet_embedding.load_state_dict(gloden_unet["unet_embedding"])
404
+ text_embedding.load_state_dict(gloden_unet["embeeding"])
405
+ _alpha = gloden_unet["alpha"]
406
+ _beta = gloden_unet["beta"]
407
+
408
+ print("Load Successfully!")
409
+
410
+ return unet_svd, unet_embedding, text_embedding, _alpha, _beta
411
+
412
+ else:
413
+ raise FileNotFoundError("No Pretrained Weights Found!")
414
+
415
+
416
+ def forward(self, initial_noise, prompt_embeds):
417
+
418
+ prompt_embeds = prompt_embeds.float().view(prompt_embeds.shape[0], -1)
419
+ text_emb = self.text_embedding(initial_noise.float(), prompt_embeds)
420
+
421
+ encoder_hidden_states_svd = initial_noise
422
+ encoder_hidden_states_embedding = initial_noise + text_emb
423
+
424
+ golden_embedding = self.unet_embedding(encoder_hidden_states_embedding.float())
425
+
426
+ golden_noise = self.unet_svd(encoder_hidden_states_svd.float()) + (
427
+ 2 * torch.sigmoid(self._alpha) - 1) * text_emb + self._beta * golden_embedding
428
+
429
+ return golden_noise
430
+
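A hedged usage sketch for the SDXL-resolution noise-prompt network defined above: it assumes the bundled `sdxl.pth` checkpoint is present (with the `unet_svd`, `unet_embedding`, `embeeding`, `alpha`, and `beta` keys used by `save_model`/`get_model`) and stands in random tensors for the SDXL prompt embeddings.

```python
import torch
from SVDNoiseUnet import NPNet128

npnet = NPNet128('SDXL', pretrained_path='./sdxl.pth', device='cuda')
latents = torch.randn(1, 4, 128, 128, device='cuda')      # initial noise for a 1024x1024 image
prompt_embeds = torch.randn(1, 77, 2048, device='cuda')   # stand-in for pipe.encode_prompt output
golden = npnet(latents, prompt_embeds)                     # refined "golden" noise, same shape as latents
```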
app.py CHANGED
@@ -1,60 +1,966 @@
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
-
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
  import torch
8
-
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
 
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
 
 
16
 
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
19
20
  MAX_SEED = np.iinfo(np.int32).max
21
  MAX_IMAGE_SIZE = 1024
22
 
23
 
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
 
 
25
  def infer(
26
  prompt,
27
  negative_prompt,
28
  seed,
29
  randomize_seed,
30
- width,
31
- height,
32
  guidance_scale,
33
  num_inference_steps,
34
  progress=gr.Progress(track_tqdm=True),
35
  ):
36
  if randomize_seed:
37
  seed = random.randint(0, MAX_SEED)
38
 
39
- generator = torch.Generator().manual_seed(seed)
40
-
41
- image = pipe(
42
- prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
- width=width,
47
- height=height,
48
- generator=generator,
49
- ).images[0]
50
-
51
- return image, seed
52
 
53
 
54
  examples = [
55
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
  ]
59
 
60
  css = """
@@ -66,7 +972,7 @@ css = """
66
 
67
  with gr.Blocks(css=css) as demo:
68
  with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
 
71
  with gr.Row():
72
  prompt = gr.Text(
@@ -79,7 +985,13 @@ with gr.Blocks(css=css) as demo:
79
 
80
  run_button = gr.Button("Run", scale=0, variant="primary")
81
 
82
83
 
84
  with gr.Accordion("Advanced Settings", open=False):
85
  negative_prompt = gr.Text(
@@ -99,22 +1011,15 @@ with gr.Blocks(css=css) as demo:
99
 
100
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
 
102
- with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
- )
110
-
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
- )
118
 
119
  with gr.Row():
120
  guidance_scale = gr.Slider(
@@ -122,15 +1027,13 @@ with gr.Blocks(css=css) as demo:
122
  minimum=0.0,
123
  maximum=10.0,
124
  step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
126
  )
127
 
128
- num_inference_steps = gr.Slider(
 
 
129
  label="Number of inference steps",
130
- minimum=1,
131
- maximum=50,
132
- step=1,
133
- value=2, # Replace with defaults that work for your model
134
  )
135
 
136
  gr.Examples(examples=examples, inputs=[prompt])
@@ -142,12 +1045,11 @@ with gr.Blocks(css=css) as demo:
142
  negative_prompt,
143
  seed,
144
  randomize_seed,
145
- width,
146
- height,
147
  guidance_scale,
148
  num_inference_steps,
149
  ],
150
- outputs=[result, seed],
151
  )
152
 
153
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
+ import json
5
+ import spaces #[uncomment to use ZeroGPU]
6
+ from diffusers import (
7
+ AutoencoderKL,
8
+ StableDiffusionXLPipeline,
9
+ )
10
+ from huggingface_hub import login, hf_hub_download
11
+ from PIL import Image
12
+ # from huggingface_hub import login
13
+ from SVDNoiseUnet import NPNet64
14
+ import functools
15
+ import random
16
+ from free_lunch_utils import register_free_upblock2d, register_free_crossattn_upblock2d
17
  import torch
18
+ import torch.nn as nn
19
+ from einops import rearrange
20
+ from torchvision.utils import make_grid
21
+ import time
22
+ from pytorch_lightning import seed_everything
23
+ from torch import autocast
24
+ from contextlib import contextmanager, nullcontext
25
+ import accelerate
26
+ import torchsde
27
+ from SVDNoiseUnet import NPNet128
28
+ from tqdm import tqdm, trange
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
+ model_repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5" # Replace with the model you would like to use
31
 
32
+ precision_scope = autocast
33
+
34
+ def extract_into_tensor(a, t, x_shape):
35
+ b, *_ = t.shape
36
+ out = a.gather(-1, t)
37
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
38
 
 
 
39
 
40
+ def append_zero(x):
41
+ return torch.cat([x, x.new_zeros([1])])
42
+
43
+ # New helper to load a list-of-dicts preference JSON
44
+ # JSON schema: [ { 'human_preference': [int], 'prompt': str, 'file_path': [str] }, ... ]
45
+ def load_preference_json(json_path: str) -> list[dict]:
46
+ """Load records from a JSON file formatted as a list of preference dicts."""
47
+ with open(json_path, 'r') as f:
48
+ data = json.load(f)
49
+ return data
50
+
51
+ # New helper to extract just the prompts from the preference JSON
52
+ # Returns a flat list of all 'prompt' values
53
+
54
+ def extract_prompts_from_pref_json(json_path: str) -> list[str]:
55
+ """Load a JSON of preference records and return only the prompts."""
56
+ records = load_preference_json(json_path)
57
+ return [rec['prompt'] for rec in records]
58
+
59
+ # Example usage:
60
+ # prompts = extract_prompts_from_pref_json("path/to/preference.json")
61
+ # print(prompts)
62
+
63
+ def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu',need_append_zero = True):
64
+ """Constructs the noise schedule of Karras et al. (2022)."""
65
+ ramp = torch.linspace(0, 1, n)
66
+ min_inv_rho = sigma_min ** (1 / rho)
67
+ max_inv_rho = sigma_max ** (1 / rho)
68
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
69
+ return append_zero(sigmas).to(device) if need_append_zero else sigmas.to(device)
70
+
71
+ def extract_into_tensor(a, t, x_shape):
72
+ b, *_ = t.shape
73
+ out = a.gather(-1, t)
74
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
75
+
76
+ def append_zero(x):
77
+ return torch.cat([x, x.new_zeros([1])])
78
+
79
+ def append_dims(x, target_dims):
80
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
81
+ dims_to_append = target_dims - x.ndim
82
+ if dims_to_append < 0:
83
+ raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
84
+ return x[(...,) + (None,) * dims_to_append]
85
+
86
+ class CFGDenoiser(nn.Module):
87
+ def __init__(self, model):
88
+ super().__init__()
89
+ self.inner_model = model
90
+
91
+
92
+ def prepare_sdxl_pipeline_step_parameter(self, pipe, prompts, need_cfg, device):
93
+ (
94
+ prompt_embeds,
95
+ negative_prompt_embeds,
96
+ pooled_prompt_embeds,
97
+ negative_pooled_prompt_embeds,
98
+ ) = pipe.encode_prompt(
99
+ prompt=prompts,
100
+ device=device,
101
+ do_classifier_free_guidance=need_cfg,
102
+ )
103
+ # timesteps = pipe.scheduler.timesteps
104
+
105
+ prompt_embeds = prompt_embeds.to(device)
106
+ add_text_embeds = pooled_prompt_embeds.to(device)
107
+ original_size = (1024, 1024)
108
+ crops_coords_top_left = (0, 0)
109
+ target_size = (1024, 1024)
110
+ text_encoder_projection_dim = None
111
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
112
+ if pipe.text_encoder_2 is None:
113
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
114
+ else:
115
+ text_encoder_projection_dim = pipe.text_encoder_2.config.projection_dim
116
+ passed_add_embed_dim = (
117
+ pipe.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
118
+ )
119
+ expected_add_embed_dim = pipe.unet.add_embedding.linear_1.in_features
120
+ if expected_add_embed_dim != passed_add_embed_dim:
121
+ raise ValueError(
122
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
123
+ )
124
+ add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
125
+ add_time_ids = add_time_ids.to(device)
126
+ negative_add_time_ids = add_time_ids
127
+
128
+ if need_cfg:
129
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
130
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
131
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
132
+ ret_dict = {
133
+ "text_embeds": add_text_embeds,
134
+ "time_ids": add_time_ids
135
+ }
136
+ return prompt_embeds, ret_dict
137
+
138
+
139
+ def get_golden_noised(self, x, sigma,sigma_nxt, prompt, cond_scale,need_distill_uncond=False,tmp_list = [], uncond_list = [],noise_training_list={}):
140
+ x_in = torch.cat([x] * 2)
141
+ sigma_in = torch.cat([sigma] * 2)
142
+ sigma_nxt = torch.cat([sigma_nxt] * 2)
143
+ prompt_embeds, cond_kwargs = self.prepare_sdxl_pipeline_step_parameter(self.inner_model.pipe, prompt, need_cfg=True, device=self.inner_model.pipe.device)
144
+ _, ret = self.inner_model.get_customed_golden_noise(x_in
145
+ , cond_scale
146
+ , sigma_in, sigma_nxt
147
+ , True
148
+ , noise_training_list=noise_training_list
149
+ , encoder_hidden_states=prompt_embeds.to(device=x.device, dtype=x.dtype)
150
+ , added_cond_kwargs=cond_kwargs).chunk(2)
151
+
152
+ return ret
153
+
154
+ def forward(self, x, sigma, prompt, cond_scale,need_distill_uncond=False,tmp_list = [], uncond_list = []):
155
+ prompt_embeds, cond_kwargs = self.prepare_sdxl_pipeline_step_parameter(self.inner_model.pipe, prompt, need_cfg=True, device=self.inner_model.pipe.device)
156
+ # w = cond_scale * x.new_ones([x.shape[0]])
157
+ # w_embedding = guidance_scale_embedding(w, embedding_dim=self.inner_model.inner_model.config["time_cond_proj_dim"])
158
+ # w_embedding = w_embedding.to(device=x.device, dtype=x.dtype)
159
+ # # t = self.inner_model.sigma_to_t(sigma)
160
+ # cond = self.inner_model(
161
+ # x,
162
+ # sigma,
163
+ # timestep_cond=w_embedding,
164
+ # encoder_hidden_states=cond.to(device=x.device, dtype=x.dtype),
165
+ # )
166
+ # return cond
167
+ x_in = torch.cat([x] * 2)
168
+ sigma_in = torch.cat([sigma] * 2)
169
+ # cond_in = torch.cat([uncond, cond])
170
+ uncond, cond = self.inner_model(x_in
171
+ , sigma_in
172
+ , tmp_list
173
+ , encoder_hidden_states=prompt_embeds.to(device=x.device, dtype=x.dtype)
174
+ , added_cond_kwargs=cond_kwargs).chunk(2)
175
+ if need_distill_uncond:
176
+ uncond_list.append(uncond)
177
+ return prompt_embeds, uncond + (cond - uncond) * cond_scale
178
+
179
+
180
+ class DiscreteSchedule(nn.Module):
181
+ """A mapping between continuous noise levels (sigmas) and a list of discrete noise
182
+ levels."""
183
+
184
+ def __init__(self, sigmas, quantize):
185
+ super().__init__()
186
+ self.register_buffer('sigmas', sigmas)
187
+ self.register_buffer('log_sigmas', sigmas.log())
188
+ self.quantize = quantize
189
+
190
+ @property
191
+ def sigma_min(self):
192
+ return self.sigmas[0]
193
+
194
+ @property
195
+ def sigma_max(self):
196
+ return self.sigmas[-1]
197
+
198
+ def get_sigmas(self, n=None):
199
+ if n is None:
200
+ return append_zero(self.sigmas.flip(0))
201
+ t_max = len(self.sigmas) - 1
202
+ t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
203
+ return append_zero(self.t_to_sigma(t))
204
+
205
+ def sigma_to_t(self, sigma, quantize=None):
206
+ quantize = self.quantize if quantize is None else quantize
207
+ log_sigma = sigma.log()
208
+ dists = log_sigma - self.log_sigmas[:, None]
209
+ if quantize:
210
+ return dists.abs().argmin(dim=0).view(sigma.shape)
211
+ low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
212
+ high_idx = low_idx + 1
213
+ low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
214
+ w = (low - log_sigma) / (low - high)
215
+ w = w.clamp(0, 1)
216
+ t = (1 - w) * low_idx + w * high_idx
217
+ return t.view(sigma.shape)
218
+
219
+ def t_to_sigma(self, t):
220
+ t = t.float()
221
+ low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
222
+ log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
223
+ return log_sigma.exp()
224
+
225
+ class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
226
+ """A wrapper for discrete schedule DDPM models that output eps (the predicted
227
+ noise)."""
228
+
229
+ def __init__(self, pipe, alphas_cumprod, quantize = False):
230
+ super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
231
+ self.pipe = pipe
232
+ self.inner_model = pipe.unet
233
+ # self.alphas_cumprod = alphas_cumprod.flip(0)
234
+ # Prepare a reversed version of alphas_cumprod for backward scheduling
235
+ self.register_buffer('alphas_cumprod', alphas_cumprod)
236
+ # self.register_buffer('alphas_cumprod_prev', append_zero(alphas_cumprod[:-1]))
237
+ self.sigma_data = 1.
238
+
239
+ def get_scalings(self, sigma):
240
+ c_out = -sigma
241
+ c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
242
+ return c_out, c_in
243
+
244
+ def get_eps(self, *args, **kwargs):
245
+ return self.inner_model(*args, **kwargs)
246
+
247
+ def get_alphact_and_sigma(self, timesteps, x_0, noise):
248
+ high_idx = torch.ceil(timesteps).int()
249
+ low_idx = torch.floor(timesteps).int()
250
+
251
+ nxt_ts = timesteps - timesteps.new_ones(timesteps.shape[0])
252
+
253
+ w = (timesteps - low_idx) / (high_idx - low_idx)
254
+
255
+ beta_1 = torch.tensor([1e-4],dtype=torch.float32)
256
+ beta_T = torch.tensor([0.02],dtype=torch.float32)
257
+ ddpm_max_step = torch.tensor([1000.0],dtype=torch.float32)
258
+
259
+ beta_t: torch.Tensor = (beta_T - beta_1) / ddpm_max_step * timesteps + beta_1
260
+ beta_t_prev: torch.Tensor = (beta_T - beta_1) / ddpm_max_step * nxt_ts + beta_1
261
+
262
+ alpha_t = beta_t.new_ones(beta_t.shape[0]) - beta_t
263
+ alpha_t_prev = beta_t.new_ones(beta_t.shape[0]) - beta_t_prev
264
+
265
+ dir_xt = (1. - alpha_t_prev).sqrt() * noise
266
+ x_prev = alpha_t_prev.sqrt() * x_0 + dir_xt + noise
267
+
268
+ alpha_cumprod_t_floor = self.alphas_cumprod[low_idx]
269
+ alpha_cumprod_t = (alpha_cumprod_t_floor * alpha_t) #.unsqueeze(1)
270
+ sqrt_alpha_cumprod_t = torch.sqrt(alpha_cumprod_t)
271
+ sigmas = torch.sqrt(alpha_cumprod_t.new_ones(alpha_cumprod_t.shape[0]) - alpha_cumprod_t)
272
+
273
+ # Fix broadcasting
274
+ sqrt_alpha_cumprod_t = sqrt_alpha_cumprod_t[:, None, None]
275
+ sigmas = sigmas[:, None, None]
276
+ return alpha_cumprod_t, sigmas
277
+
278
+ def get_c_ins(self,sigmas): # use to adjust loss
279
+ ret = list()
280
+ for sigma in sigmas:
281
+ _, c_in = self.get_scalings(sigma=sigma)
282
+ ret.append(c_in)
283
+ return ret
284
+
285
+ # def predicted_origin(model_output, timesteps, sample, alphas, sigmas, prediction_type = "epsilon"):
286
+ # if prediction_type == "epsilon":
287
+ # sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
288
+ # alphas = extract_into_tensor(alphas, timesteps, sample.shape)
289
+ # pred_x_0 = (sample - sigmas * model_output) / alphas
290
+ # elif prediction_type == "v_prediction":
291
+ # sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
292
+ # alphas = extract_into_tensor(alphas, timesteps, sample.shape)
293
+ # pred_x_0 = alphas * sample - sigmas * model_output
294
+ # else:
295
+ # raise ValueError(f"Prediction type {prediction_type} currently not supported.")
296
+ # return pred_x_0
297
+
298
+ def get_customed_golden_noise(self
299
+ , input
300
+ , unconditional_guidance_scale:float
301
+ , sigma
302
+ , sigma_nxt
303
+ , need_cond = True
304
+ , noise_training_list = {}
305
+ , **kwargs):
306
+ """User should ensure the input is a pure noise.
307
+ It is a custom golden noise, not the one proposed in the paper.
308
+ The one proposed in the paper may be implemented in the future."""
309
+ c_out, c_in = [append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
310
+
311
+ sigma_fn = lambda t: t.neg().exp()
312
+ t_fn = lambda sigma: sigma.log().neg()
313
+ if need_cond:
314
+ _, tmp_img = (input * c_in).chunk(2)
315
+ else :
316
+ tmp_img = input * c_in
317
+ # print(tmp_img.max())
318
+ # tmp_list.append(tmp_img)
319
+ eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs).sample
320
+ x_0 = input + eps * c_out
321
+ # normal_form_input = input * c_in
322
+ x_0_uncond, x_0 = x_0.chunk(2)
323
+ x_0 = x_0_uncond + unconditional_guidance_scale * (x_0 - x_0_uncond)
324
+ x_0 = torch.cat([x_0] * 2)
325
+
326
+
327
+ t, t_next = t_fn(sigma), t_fn(sigma_nxt)
328
+ h = t_next - t
329
+
330
+ x = (append_dims(sigma_fn(t_next) / sigma_fn(t),input.ndim)) * input - append_dims((-h).expm1(),input.ndim) * x_0
331
+
332
+ c_out_2, c_in_2 = [append_dims(x, input.ndim) for x in self.get_scalings(sigma_nxt)]
333
+
334
+
335
+ # e_t_uncond_ret, e_t_ret = self.get_eps(x * c_in_2, self.sigma_to_t(sigma_nxt), **kwargs).sample.chunk(2)
336
+ eps_ret = self.get_eps(x * c_in_2, self.sigma_to_t(sigma_nxt), **kwargs).sample
337
+ org_golden_noise = False
338
+ x_1 = x + eps_ret * c_out_2
339
+ if org_golden_noise:
340
+ ret = (x + append_dims((-h).expm1(),input.ndim) * x_1) / (append_dims(sigma_fn(t_next) / sigma_fn(t),input.ndim))
341
+ else :
342
+ e_t_uncond_ret, e_t_ret = eps_ret.chunk(2)
343
+ e_t_ret = e_t_uncond_ret + 1.0 * (e_t_ret - e_t_uncond_ret)
344
+ e_t_ret = torch.cat([e_t_ret] * 2)
345
+ ret = x_0 + e_t_ret * append_dims(sigma,input.ndim)
346
+
347
+ noise_training_list['org_noise'] = input * c_in
348
+ noise_training_list['golden_noise'] = ret * c_in
349
+ # noise_training_list.append(tmp_dict)
350
+ return ret
351
+
352
+ # timesteps = self.sigma_to_t(sigma)
353
+
354
+ # high_idx = torch.ceil(timesteps).int().to(input.device)
355
+ # low_idx = torch.floor(timesteps).int().to(input.device)
356
+
357
+ # nxt_ts = (timesteps - timesteps.new_ones(timesteps.shape[0])).to(input.device)
358
+
359
+ # w = (timesteps - low_idx) / (high_idx - low_idx)
360
+
361
+ # beta_1 = torch.tensor([1e-4],dtype=torch.float32).to(input.device)
362
+ # beta_T = torch.tensor([0.02],dtype=torch.float32).to(input.device)
363
+ # ddpm_max_step = torch.tensor([1000.0],dtype=torch.float32).to(input.device)
364
+
365
+ # beta_t: torch.Tensor = (beta_T - beta_1) / ddpm_max_step * timesteps + beta_1
366
+ # beta_t_prev: torch.Tensor = (beta_T - beta_1) / ddpm_max_step * nxt_ts + beta_1
367
+
368
+ # alpha_t = beta_t.new_ones(beta_t.shape[0]) - beta_t
369
+ # alpha_t = append_dims(alpha_t, e_t.ndim)
370
+ # alpha_t_prev = beta_t_prev.new_ones(beta_t_prev.shape[0]) - beta_t_prev
371
+ # alpha_t_prev = append_dims(alpha_t_prev, e_t.ndim)
372
+ # alpha_cumprod_t_floor = self.alphas_cumprod[low_idx]
373
+ # alpha_cumprod_t_floor = append_dims(alpha_cumprod_t_floor, e_t.ndim)
374
+ # alpha_cumprod_t:torch.Tensor = (alpha_cumprod_t_floor * alpha_t) #.unsqueeze(1)
375
+ # alpha_cumprod_t_prev:torch.Tensor = (alpha_cumprod_t_floor * alpha_t_prev) #.unsqueeze(1)
376
+
377
+ # sqrt_one_minus_alphas = (1 - alpha_cumprod_t).sqrt()
378
+
379
+ # dir_xt = (1. - alpha_cumprod_t_prev).sqrt() * e_t
380
+ # x_prev = alpha_cumprod_t_prev.sqrt() * x_0 + dir_xt
381
+
382
+ # e_t_uncond_ret, e_t_ret = self.get_eps(x_prev, nxt_ts, **kwargs).sample.chunk(2)
383
+ # e_t_ret = e_t_uncond_ret + 1.0 * (e_t_ret - e_t_uncond_ret)
384
+ # e_t_ret = torch.cat([e_t_ret] * 2)
385
+ # x_ret = alpha_t.sqrt() * x_0 + sqrt_one_minus_alphas * e_t_ret
386
+
387
+ # return x_ret
388
+
389
+ def forward(self, input, sigma, tmp_list=[], need_cond = True, **kwargs):
390
+ # c_out_1, c_in_1 = [append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
391
+ # if need_cond:
392
+ # tmp_img = input * c_in_1
393
+ # else :
394
+ # tmp_img = input * c_in_1
395
+ # tmp_list.append(tmp_img)
396
+ # timestep = self.sigma_to_t(sigma)
397
+ # eps = self.get_eps(sample = input * c_in_1, timestep = timestep, **kwargs).sample
398
+ # c_skip, c_out = self.scalings_for_boundary_conditions(timestep=self.sigma_to_t(sigma))
399
+ # # return (input + eps * c_out_1) * c_out + input * c_in_1 * c_skip
400
+ # return (input + eps * c_out_1)
401
+ c_out, c_in = [append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
402
+ if need_cond:
403
+ _, tmp_img = (input * c_in).chunk(2)
404
+ else :
405
+ tmp_img = input * c_in
406
+ # print(tmp_img.max())
407
+ eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs).sample
408
+ tmp_x0 = input + eps * c_out
409
+ tmp_dict = {'tmp_z': tmp_img, 'tmp_x0': tmp_x0}
410
+ tmp_list.append(tmp_dict)
411
+ return tmp_x0 #input + eps * c_out
412
+
413
+ def get_special_sigmas_with_timesteps(self,timesteps):
414
+ low_idx, high_idx, w = np.minimum(np.floor(timesteps), 999).astype(int), np.minimum(np.ceil(timesteps), 999).astype(int), torch.from_numpy(timesteps - np.floor(timesteps))
415
+ self.alphas_cumprod = self.alphas_cumprod.to('cpu')
416
+ alphas = (1 - w) * self.alphas_cumprod[low_idx] + w * self.alphas_cumprod[high_idx]
417
+ return ((1 - alphas) / alphas) ** 0.5
418
+
419
+ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
420
+ """Calculates the noise level (sigma_down) to step down to and the amount
421
+ of noise to add (sigma_up) when doing an ancestral sampling step."""
422
+ if not eta:
423
+ return sigma_to, 0.
424
+ sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
425
+ sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
426
+ return sigma_down, sigma_up
427
+
428
+ def to_d(x, sigma, denoised):
429
+ """Converts a denoiser output to a Karras ODE derivative."""
430
+ return (x - denoised) / append_dims(sigma, x.ndim)
431
+
432
+ class BatchedBrownianTree:
433
+ """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
434
+ def __init__(self, x, t0, t1, seed=None, **kwargs):
435
+ t0, t1, self.sign = self.sort(t0, t1)
436
+ w0 = kwargs.get('w0', torch.zeros_like(x))
437
+ if seed is None:
438
+ seed = torch.randint(0, 2 ** 63 - 1, []).item()
439
+ self.batched = True
440
+ try:
441
+ assert len(seed) == x.shape[0]
442
+ w0 = w0[0]
443
+ except TypeError:
444
+ seed = [seed]
445
+ self.batched = False
446
+ self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
447
+
448
+ @staticmethod
449
+ def sort(a, b):
450
+ return (a, b, 1) if a < b else (b, a, -1)
451
+
452
+ def __call__(self, t0, t1):
453
+ t0, t1, sign = self.sort(t0, t1)
454
+ w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
455
+ return w if self.batched else w[0]
456
+
457
+ class BrownianTreeNoiseSampler:
458
+ """A noise sampler backed by a torchsde.BrownianTree.
459
+
460
+ Args:
461
+ x (Tensor): The tensor whose shape, device and dtype to use to generate
462
+ random samples.
463
+ sigma_min (float): The low end of the valid interval.
464
+ sigma_max (float): The high end of the valid interval.
465
+ seed (int or List[int]): The random seed. If a list of seeds is
466
+ supplied instead of a single integer, then the noise sampler will
467
+ use one BrownianTree per batch item, each with its own seed.
468
+ transform (callable): A function that maps sigma to the sampler's
469
+ internal timestep.
470
+ """
471
+ def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
472
+ self.transform = transform
473
+ t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
474
+ self.tree = BatchedBrownianTree(x, t0, t1, seed)
475
+
476
+ def __call__(self, sigma, sigma_next):
477
+ t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
478
+ return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
479
+
480
+ @torch.no_grad()
481
+ def sample_euler(model
482
+ , x
483
+ , sigmas
484
+ , extra_args=None
485
+ , callback=None
486
+ , disable=None
487
+ , s_churn=0.
488
+ , s_tmin=0.
489
+ , s_tmax=float('inf')
490
+ , tmp_list=[]
491
+ , uncond_list=[]
492
+ , need_distill_uncond=False
493
+ , start_free_step = 1
494
+ , noise_training_list={}
495
+ , s_noise=1.):
496
+ """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
497
+ extra_args = {} if extra_args is None else extra_args
498
+ s_in = x.new_ones([x.shape[0]])
499
+ intermediates = {'x_inter': [x],'pred_x0': []}
500
+ register_free_upblock2d(model.inner_model.pipe, b1=1.0, b2=1.0, s1=1.0, s2=1.0)
501
+ register_free_crossattn_upblock2d(model.inner_model.pipe, b1=1.0, b2=1.0, s1=1.0, s2=1.0)
502
+ for i in trange(len(sigmas) - 1, disable=disable):
503
+ if i == start_free_step:
504
+ register_free_upblock2d(model.inner_model.pipe, b1=1.3, b2=1.4, s1=0.9, s2=0.2)
505
+ register_free_crossattn_upblock2d(model.inner_model.pipe, b1=1.3, b2=1.4, s1=0.9, s2=0.2)
506
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
507
+ eps = torch.randn_like(x) * s_noise
508
+ sigma_hat = sigmas[i] * (gamma + 1)
509
+ if gamma > 0:
510
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
511
+ prompt_embeds, denoised = model(x, sigmas[i] * s_in, tmp_list=tmp_list,need_distill_uncond=need_distill_uncond,uncond_list=uncond_list, **extra_args)
512
+ d = to_d(x, sigma_hat, denoised)
513
+ if callback is not None:
514
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
515
+ dt = sigmas[i + 1] - sigma_hat
516
+ # Euler method
517
+ x = x + d * dt
518
+ intermediates['pred_x0'].append(denoised)
519
+ intermediates['x_inter'].append(x)
520
+ return prompt_embeds, intermediates, x
521
+
522
+ @torch.no_grad()
523
+ def sample_heun(model
524
+ , x
525
+ , sigmas
526
+ , extra_args=None
527
+ , callback=None
528
+ , disable=None
529
+ , s_churn=0.
530
+ , s_tmin=0.
531
+ , s_tmax=float('inf')
532
+ , tmp_list=[]
533
+ , uncond_list=[]
534
+ , need_distill_uncond=False
535
+ , noise_training_list={}
536
+ , s_noise=1.):
537
+ """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
538
+ extra_args = {} if extra_args is None else extra_args
539
+ s_in = x.new_ones([x.shape[0]])
540
+ intermediates = {'x_inter': [x],'pred_x0': []}
541
+ register_free_upblock2d(model.inner_model.pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
542
+ register_free_crossattn_upblock2d(model.inner_model.pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
543
+ for i in trange(len(sigmas) - 1, disable=disable):
544
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
545
+ eps = torch.randn_like(x) * s_noise
546
+ sigma_hat = sigmas[i] * (gamma + 1)
547
+ if gamma > 0:
548
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
549
+ prompt_embeds, denoised = model(x, sigmas[i] * s_in, tmp_list=tmp_list,need_distill_uncond=need_distill_uncond,uncond_list=uncond_list, **extra_args)
550
+ d = to_d(x, sigma_hat, denoised)
551
+ if callback is not None:
552
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
553
+ dt = sigmas[i + 1] - sigma_hat
554
+ if sigmas[i + 1] == 0:
555
+ # Euler method
556
+ x = x + d * dt
557
+ else:
558
+ # Heun's method
559
+ x_2 = x + d * dt
560
+ _, denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
561
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
562
+ d_prime = (d + d_2) / 2
563
+ x = x + d_prime * dt
564
+ intermediates['pred_x0'].append(denoised_2)
565
+ intermediates['x_inter'].append(x)
566
+ return prompt_embeds, intermediates, x
567
+
568
+ @torch.no_grad()
569
+ def sample_dpmpp_ode(model
570
+ , x
571
+ , sigmas
572
+ , need_golden_noise = False
573
+ , start_free_step = 1
574
+ , extra_args=None, callback=None
575
+ , disable=None,tmp_list=[]
576
+ , need_distill_uncond=False
577
+ , need_raw_noise=False
578
+ , uncond_list=[]
579
+ , noise_training_list={}):
580
+ """DPM-Solver++."""
581
+ extra_args = {} if extra_args is None else extra_args
582
+ s_in = x.new_ones([x.shape[0]])
583
+ sigma_fn = lambda t: t.neg().exp()
584
+ t_fn = lambda sigma: sigma.log().neg()
585
+ old_denoised = None
586
+ if need_raw_noise:
587
+ x = model.get_golden_noised(x=x,sigma=sigmas[0] * s_in, sigma_nxt=(sigmas[0] - 0.28) * s_in, noise_training_list=noise_training_list,**extra_args)
588
+ register_free_upblock2d(model.inner_model.pipe, b1=1, b2=1, s1=1, s2=1)
589
+ register_free_crossattn_upblock2d(model.inner_model.pipe, b1=1, b2=1, s1=1, s2=1)
590
+ intermediates = {'x_inter': [x],'pred_x0': []}
591
+
592
+ for i in trange(len(sigmas) - 1, disable=disable):
593
+ if i == start_free_step:
594
+ register_free_upblock2d(model.inner_model.pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
595
+ register_free_crossattn_upblock2d(model.inner_model.pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
596
+ # macs, params = profile(model, inputs=(x, sigmas[i] * s_in,*extra_args.values(),need_distill_uncond,tmp_list,uncond_list, ))
597
+ prompt_embeds, denoised = model(x, sigmas[i] * s_in, tmp_list=tmp_list,need_distill_uncond=need_distill_uncond,uncond_list=uncond_list, **extra_args)
598
+ if callback is not None:
599
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
600
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
601
+ h = t_next - t
602
+
603
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
604
+ intermediates['pred_x0'].append(denoised)
605
+ intermediates['x_inter'].append(x)
606
+
607
+ # print(denoised_d.max())
608
+
609
+ # intermediates['noise'].append(denoised_d)
610
+ return prompt_embeds, intermediates,x
611
+
612
+ @torch.no_grad()
613
+ def sample_dpmpp_sde(model
614
+ , x
615
+ , sigmas
616
+ , need_golden_noise = False
617
+ , extra_args=None
618
+ , callback=None
619
+ , tmp_list=[]
620
+ , need_distill_uncond=False
621
+ , uncond_list=[]
622
+ , disable=None, eta=1.
623
+ , s_noise=1.
624
+ , noise_sampler=None
625
+ , r=1 / 2):
626
+ """DPM-Solver++ (stochastic)."""
627
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
628
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
629
+ extra_args = {} if extra_args is None else extra_args
630
+ s_in = x.new_ones([x.shape[0]])
631
+ sigma_fn = lambda t: t.neg().exp()
632
+ t_fn = lambda sigma: sigma.log().neg()
633
+ if need_golden_noise:
634
+ x = model.get_golden_noised(x=x,sigma=sigmas[0] * s_in, sigma_nxt=sigmas[1] * s_in,**extra_args)
635
+
636
+ intermediates = {'x_inter': [x],'pred_x0': []}
637
+
638
+ for i in trange(len(sigmas) - 1, disable=disable):
639
+ prompt_embeds, denoised = model(x, sigmas[i] * s_in, tmp_list=tmp_list,need_distill_uncond=need_distill_uncond,uncond_list=uncond_list, **extra_args)
640
+ if callback is not None:
641
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
642
+ if sigmas[i + 1] == 0:
643
+ # Euler method
644
+ d = to_d(x, sigmas[i], denoised)
645
+ dt = sigmas[i + 1] - sigmas[i]
646
+ x = x + d * dt
647
+ intermediates['pred_x0'].append(denoised)
648
+ intermediates['x_inter'].append(x)
649
+ else:
650
+ # DPM-Solver++
651
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
652
+ h = t_next - t
653
+ s = t + h * r
654
+ fac = 1 / (2 * r)
655
+
656
+ # Step 1
657
+ sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
658
+ s_ = t_fn(sd)
659
+ x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
660
+ x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
661
+ prompt_embeds, denoised_2 = model(x_2, sigma_fn(s) * s_in, tmp_list=tmp_list,need_distill_uncond=need_distill_uncond,uncond_list=uncond_list, **extra_args) #(x, sigmas[i] * s_in, tmp_list=tmp_list,need_distill_uncond=need_distill_uncond,uncond_list=uncond_list, **extra_args)
662
+
663
+ # Step 2
664
+ sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
665
+ t_next_ = t_fn(sd)
666
+ denoised_d = (1 - fac) * denoised + fac * denoised_2
667
+ x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
668
+ intermediates['pred_x0'].append(x)
669
+ x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
670
+ intermediates['x_inter'].append(x)
671
+ return prompt_embeds, intermediates,x
672
+
673
+
674
+ @torch.no_grad()
+ def sample_dpmpp_2m(model
+                     , x
+                     , sigmas
+                     # , need_golden_noise = True
+                     , extra_args=None
+                     , callback=None
+                     , disable=None
+                     , tmp_list=[]
+                     , need_distill_uncond=False
+                     , start_free_step=9
+                     , uncond_list=[]
+                     , stop_t=None):
+     """DPM-Solver++(2M)."""
+     extra_args = {} if extra_args is None else extra_args
+     s_in = x.new_ones([x.shape[0]])
+     sigma_fn = lambda t: t.neg().exp()
+     t_fn = lambda sigma: sigma.log().neg()
+     old_denoised = None
+     # if need_golden_noise:
+     #     x = model.get_golden_noised(x=x, sigma=sigmas[0] * s_in, sigma_nxt=sigmas[1] * s_in, **extra_args)
+     intermediates = {'x_inter': [x], 'pred_x0': []}
+     register_free_upblock2d(model.inner_model.pipe, b1=1, b2=1, s1=1, s2=1)
+     register_free_crossattn_upblock2d(model.inner_model.pipe, b1=1, b2=1, s1=1, s2=1)
+
+     for i in trange(len(sigmas) - 1, disable=disable):
+         if i == start_free_step and len(sigmas) > 6:
+             register_free_upblock2d(model.inner_model.pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
+             register_free_crossattn_upblock2d(model.inner_model.pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
+         # else:
+         #     register_free_upblock2d(model.inner_model.pipe, b1=1.1, b2=1.1, s1=1.0, s2=1.0)
+         #     register_free_crossattn_upblock2d(model.inner_model.pipe, b1=1.1, b2=1.1, s1=1.0, s2=1.0)
+         # macs, params = profile(model, inputs=(x, sigmas[i] * s_in, *extra_args.values(), need_distill_uncond, tmp_list, uncond_list, ))
+         prompt_embeds, denoised = model(x, sigmas[i] * s_in, tmp_list=tmp_list, need_distill_uncond=need_distill_uncond, uncond_list=uncond_list, **extra_args)
+         if callback is not None:
+             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+         t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
+         h = t_next - t
+         if old_denoised is None or sigmas[i + 1] == 0:
+             x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
+             intermediates['pred_x0'].append(denoised)
+             intermediates['x_inter'].append(x)
+         else:
+             h_last = t - t_fn(sigmas[i - 1])
+             r = h_last / h
+             denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
+             x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
+             intermediates['x_inter'].append(x)
+             intermediates['pred_x0'].append(denoised)
+             # print(denoised_d.max())
+         old_denoised = denoised
+         if stop_t is not None and i == stop_t:
+             return intermediates, x
+     # intermediates['noise'].append(denoised_d)
+     return prompt_embeds, intermediates, x
+
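For reference, the multistep branch above implements the DPM-Solver++(2M) update. With $t = -\log\sigma$, $h_i = t_{i+1} - t_i$ and $r_i = h_{i-1}/h_i$, the corrected data prediction and the state update are

$\tilde D_i = \left(1 + \tfrac{1}{2 r_i}\right) D_i - \tfrac{1}{2 r_i} D_{i-1}$, $\qquad x_{i+1} = \frac{\sigma_{i+1}}{\sigma_i}\, x_i - \left(e^{-h_i} - 1\right) \tilde D_i$,

which is exactly what the denoised_d and x assignments compute; the first iteration and the final one ($\sigma_{i+1} = 0$) fall back to the first-order update.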
+ # Adapted from pipelines.StableDiffusionPipeline.encode_prompt
+ def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):
+     captions = []
+     for caption in prompt_batch:
+         if random.random() < proportion_empty_prompts:
+             captions.append("")
+         elif isinstance(caption, str):
+             captions.append(caption)
+         elif isinstance(caption, (list, np.ndarray)):
+             # take a random caption if there are multiple
+             captions.append(random.choice(caption) if is_train else caption[0])
+
+     with torch.no_grad():
+         text_inputs = tokenizer(
+             captions,
+             padding="max_length",
+             max_length=tokenizer.model_max_length,
+             truncation=True,
+             return_tensors="pt",
+         )
+         text_input_ids = text_inputs.input_ids
+         prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device))[0]
+
+     return prompt_embeds
+
+ def chunk(it, size):
+     it = iter(it)
+     return iter(lambda: tuple(islice(it, size)), ())
+
+
+ torch_dtype = torch.float32
+ device = "cuda"
+ # pipe = StableDiffusionPipeline.from_single_file("./counterfeit/Counterfeit-V3.0_fp32.safetensors")
+ repo_id = "madebyollin/sdxl-vae-fp16-fix"  # fp16-safe SDXL VAE
+ filename = "sdxl_vae.safetensors"
+ downloaded_path = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=".")
+
+ # pipe = StableDiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4')
+ vae = AutoencoderKL.from_single_file(downloaded_path, torch_dtype=torch_dtype)
+ vae.to('cuda')
+
+ pipe = StableDiffusionXLPipeline.from_pretrained("Lykon/dreamshaper-xl-1-0", torch_dtype=torch_dtype, vae=vae)
+ # pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, vae=vae)
+
+ pipe.to('cuda')
+ npn_net = NPNet128('SDXL', './sdxl.pth')
+
+ pipe = pipe.to(device)
+ register_free_upblock2d(pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
+ register_free_crossattn_upblock2d(pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
  MAX_SEED = np.iinfo(np.int32).max
  MAX_IMAGE_SIZE = 1024
+ noise_scheduler = pipe.scheduler
+ alpha_schedule = noise_scheduler.alphas_cumprod.to(device=device, dtype=torch_dtype)
+ model_wrap = DiscreteEpsDDPMDenoiser(pipe.unet, alpha_schedule, quantize=False)
+ accelerator = accelerate.Accelerator()
+
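DiscreteEpsDDPMDenoiser comes from k-diffusion and wraps the ε-prediction UNet so that the samplers above can work directly with the denoised (data-space) prediction. Below is a minimal sketch of what such a wrapper computes, assuming standard k-diffusion semantics; the helper name is illustrative, not from this repo, and the wrapper used here adds extra bookkeeping such as the get_c_ins / get_special_sigmas_with_timesteps helpers used later in app.py.

    def denoised_from_eps(eps_model, x, sigma, t):
        # k-diffusion-style preconditioning (assumed semantics, not this repo's code):
        # scale the input, query the ε-model at the discrete timestep `t` that matches
        # `sigma`, then convert the ε-prediction back into a data prediction D_θ(x, σ).
        c_in = 1 / (sigma ** 2 + 1) ** 0.5
        eps = eps_model(x * c_in, t)
        return x - eps * sigma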
+ def generate_image_with_steps(prompt, negative_prompt, seed, width, height, guidance_scale, num_inference_steps):
+     """Helper that generates one image with the requested number of inference steps."""
+     prompts = [prompt]
+     if num_inference_steps <= 10:
+         register_free_upblock2d(pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
+         register_free_crossattn_upblock2d(pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
+     else:
+         register_free_upblock2d(pipe, b1=1, b2=1, s1=1, s2=1)
+         register_free_crossattn_upblock2d(pipe, b1=1, b2=1, s1=1, s2=1)
+     # Seed randomization is handled by the caller (infer); `seed` is used as-is here.
+
+     def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
+         prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
+         return {"prompt_embeds": prompt_embeds}
+
+     compute_embeddings_fn = functools.partial(
+         compute_embeddings,
+         proportion_empty_prompts=0,
+         text_encoder=pipe.text_encoder,
+         tokenizer=pipe.tokenizer,
+     )
+     generator = torch.Generator(device=device).manual_seed(seed)
+
+     intermediate_photos = list()
+     # prompts = prompts[0]
+
+     # if isinstance(prompts, tuple) or isinstance(prompts, str):
+     #     prompts = list(prompts)
+     if isinstance(prompts, str):
+         prompts = prompts  # + 'high quality, best quality, masterpiece, 4K, highres, extremely detailed, ultra-detailed'
+         prompts = (prompts,)
+     if isinstance(prompts, tuple) or isinstance(prompts, str):
+         prompts = list(prompts)
+
+     shape = [4, height // 8, width // 8]
+     start_free_step = num_inference_steps
+     fir_stage_sigmas_ct = None
+     sec_stage_sigmas_ct = None
+     # sigmas = model_wrap.get_sigmas(opt.ddim_steps).to(device=device)
+     if num_inference_steps == 5:
+         sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()
+         sigmas = get_sigmas_karras(8, sigma_min, sigma_max, rho=5.0, device=device)  # 6.0 if 5 else 10.0
+
+         ct_start, ct_end = model_wrap.sigma_to_t(sigmas[0]), model_wrap.sigma_to_t(sigmas[6])
+         # sigma_kct_start, sigma_kct_end = sigmas[0].item(), sigmas[5].item()
+         ct = get_sigmas_karras(6, ct_end.item(), ct_start.item(), rho=1.2, device='cpu', need_append_zero=False).numpy()
+         sigmas_ct = model_wrap.get_special_sigmas_with_timesteps(ct).to(device=device)
+     elif num_inference_steps == 6:
+         sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()
+         sigmas = get_sigmas_karras(8, sigma_min, sigma_max, rho=5.0, device=device)  # 6.0 if 5 else 10.0
+
+         ct_start, ct_end = model_wrap.sigma_to_t(sigmas[0]), model_wrap.sigma_to_t(sigmas[6])
+         # sigma_kct_start, sigma_kct_end = sigmas[0].item(), sigmas[5].item()
+         ct = get_sigmas_karras(7, ct_end.item(), ct_start.item(), rho=1.2, device='cpu', need_append_zero=False).numpy()
+         sigmas_ct = model_wrap.get_special_sigmas_with_timesteps(ct).to(device=device)
+         start_free_step = 6
+         fir_stage_sigmas_ct = sigmas_ct[:-2]
+         sec_stage_sigmas_ct = sigmas_ct[-3:]
+     elif num_inference_steps == 8:
+         sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()
+
+         sigmas = get_sigmas_karras(12, sigma_min, sigma_max, rho=12.0, device=device)
+         ct_start, ct_end = model_wrap.sigma_to_t(sigmas[0]), model_wrap.sigma_to_t(sigmas[10])
+
+         ct = get_sigmas_karras(num_inference_steps + 1, ct_end.item(), ct_start.item(), rho=1.2, device='cpu', need_append_zero=False).numpy()
+         sigmas_ct = model_wrap.get_special_sigmas_with_timesteps(ct).to(device=device)
+         start_free_step = 8
+     else:
+         # Fall back to the stock diffusers pipeline for step counts the few-step solver does not handle.
+         image = pipe(prompt=prompts
+                      , num_inference_steps=num_inference_steps
+                      , guidance_scale=guidance_scale
+                      , height=height
+                      , width=width
+                      , generator=generator).images[0]
+         return image
+     ts = []
+     for sigma in sigmas_ct:
+         t = model_wrap.sigma_to_t(sigma)
+         ts.append(t)
+
+     c_in = model_wrap.get_c_ins(sigmas=sigmas_ct)
+     x = torch.randn([1, *shape], device=device, generator=generator) * sigmas_ct[0]
+     model_wrap_cfg = CFGDenoiser(model_wrap)
+     (
+         c,
+         uc,
+         _,
+         _,
+     ) = pipe.encode_prompt(
+         prompt=prompts,
+         device=device,
+         do_classifier_free_guidance=True,
+     )
+     # prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+     # if (num_inference_steps != -1 or num_inference_steps <= 8) and not opt.force_not_use_NPNet:
+     x = npn_net(x, c)
+
+     extra_args = {'prompt': prompts, 'cond_scale': guidance_scale}
+     with torch.no_grad():
+         # with precision_scope("cuda" if torch.cuda.is_available() else "cpu"):
+         if not (num_inference_steps == 8 or num_inference_steps == 7):
+             prompt_embeds, guide_distill, samples_ddim = sample_dpmpp_ode(model_wrap_cfg
+                                                                           , x
+                                                                           , fir_stage_sigmas_ct
+                                                                           , extra_args=extra_args
+                                                                           , disable=not accelerator.is_main_process
+                                                                           , need_raw_noise=False
+                                                                           , tmp_list=intermediate_photos)
+             _, _, samples_ddim = sample_euler(model_wrap_cfg
+                                               , samples_ddim
+                                               , sec_stage_sigmas_ct
+                                               , extra_args=extra_args
+                                               , disable=not accelerator.is_main_process
+                                               , s_noise=0.3
+                                               , tmp_list=intermediate_photos)
+         else:
+             prompt_embeds, guide_distill, samples_ddim = sample_dpmpp_2m(model_wrap_cfg
+                                                                          , x
+                                                                          , sigmas_ct
+                                                                          , extra_args=extra_args
+                                                                          , start_free_step=start_free_step
+                                                                          , disable=not accelerator.is_main_process
+                                                                          , tmp_list=intermediate_photos)
+             # print('2m')
+
+         x_samples_ddim = pipe.vae.decode(samples_ddim / pipe.vae.config.scaling_factor).sample
+         x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+
+         for x_sample in x_samples_ddim:  # not opt.skip_save in the original script
+             x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+             image = Image.fromarray(x_sample.astype(np.uint8))
+             # base_count += 1
+
+     return image
+
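The time grids above are produced by get_sigmas_karras (defined earlier in app.py, following k-diffusion); the need_append_zero flag is specific to this repo and is not reproduced here. As a reference, a minimal sketch of the standard Karras et al. (2022) schedule such a helper computes — karras_sigmas is an illustrative name, not this repo's function:

    def karras_sigmas(n, sigma_min, sigma_max, rho=7.0):
        # sigma_i = (sigma_max^(1/rho) + i/(n-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
        ramp = torch.linspace(0, 1, n)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

Roughly speaking, a larger rho concentrates the steps near sigma_min, which is why different rho values (1.2, 5.0, 12.0) appear above for the coarse sigma grids and the timestep re-spacing.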
+ @spaces.GPU  # ZeroGPU: allocate a GPU for this call
  def infer(
      prompt,
      negative_prompt,
      seed,
      randomize_seed,
+     resolution,
      guidance_scale,
      num_inference_steps,
      progress=gr.Progress(track_tqdm=True),
  ):
      if randomize_seed:
          seed = random.randint(0, MAX_SEED)
+
+     # Parse the resolution string into width and height
+     width, height = map(int, resolution.split('x'))
+
+     # Generate an image with the selected number of steps
+     image_quick = generate_image_with_steps(prompt, negative_prompt, seed, width, height, guidance_scale, num_inference_steps)
+
+     # Generate a 50-step reference image for comparison
+     image_50_steps = generate_image_with_steps(prompt, negative_prompt, seed, width, height, guidance_scale, 50)
+
+     return image_quick, image_50_steps, seed

  examples = [
      "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+     "a painting of a virus monster playing guitar",
+     "a painting of a squirrel eating a burger",
  ]

  css = """

  with gr.Blocks(css=css) as demo:
      with gr.Column(elem_id="col-container"):
+         gr.Markdown("# Hyperparameters Are All You Need")

          with gr.Row():
              prompt = gr.Text(

              run_button = gr.Button("Run", scale=0, variant="primary")

+         with gr.Row():
+             with gr.Column():
+                 gr.Markdown("### Our Fast Inference Result")
+                 result = gr.Image(label="Quick Result", show_label=False)
+             with gr.Column():
+                 gr.Markdown("### Original 50-Step Result")
+                 result_50_steps = gr.Image(label="50 Steps Result", show_label=False)

          with gr.Accordion("Advanced Settings", open=False):
              negative_prompt = gr.Text(

              randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

+             resolution = gr.Dropdown(
+                 choices=[
+                     "1024x1024",
+                     "1216x832",
+                     "832x1216",
+                 ],
+                 value="1024x1024",
+                 label="Resolution",
+             )

              with gr.Row():
                  guidance_scale = gr.Slider(

                      minimum=0.0,
                      maximum=10.0,
                      step=0.1,
+                     value=7.5,
                  )

+                 num_inference_steps = gr.Dropdown(
+                     choices=[6, 8],
+                     value=8,
                      label="Number of inference steps",
                  )

          gr.Examples(examples=examples, inputs=[prompt])

              negative_prompt,
              seed,
              randomize_seed,
+             resolution,
              guidance_scale,
              num_inference_steps,
          ],
+         outputs=[result, result_50_steps, seed],
      )

  if __name__ == "__main__":
free_lunch_utils.py ADDED
@@ -0,0 +1,304 @@
+ import torch
+ import torch.fft as fft
+ from diffusers.utils import is_torch_version
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+
+ def isinstance_str(x: object, cls_name: str):
+     """
+     Checks whether x has any class *named* cls_name in its ancestry.
+     Doesn't require access to the class's implementation.
+
+     Useful for patching!
+     """
+
+     for _cls in x.__class__.__mro__:
+         if _cls.__name__ == cls_name:
+             return True
+
+     return False
+
+
+ def Fourier_filter(x, threshold, scale):
+     dtype = x.dtype
+     x = x.type(torch.float32)
+     # FFT
+     x_freq = fft.fftn(x, dim=(-2, -1))
+     x_freq = fft.fftshift(x_freq, dim=(-2, -1))
+
+     B, C, H, W = x_freq.shape
+     # Build the mask on the same device as the input (avoids a device mismatch
+     # if the features are not on the default CUDA device).
+     mask = torch.ones((B, C, H, W), device=x.device)
+
+     crow, ccol = H // 2, W // 2
+     # Scale the low-frequency band around the spectrum centre; all other frequencies are kept.
+     mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
+     x_freq = x_freq * mask
+
+     # IFFT
+     x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
+     x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
+
+     x_filtered = x_filtered.type(dtype)
+     return x_filtered
+
+
46
+ def up_forward(self):
47
+ def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
48
+ for resnet in self.resnets:
49
+ # pop res hidden states
50
+ res_hidden_states = res_hidden_states_tuple[-1]
51
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
52
+ #print(f"in upblock2d, hidden states shape: {hidden_states.shape}")
53
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
54
+
55
+ if self.training and self.gradient_checkpointing:
56
+
57
+ def create_custom_forward(module):
58
+ def custom_forward(*inputs):
59
+ return module(*inputs)
60
+
61
+ return custom_forward
62
+
63
+ if is_torch_version(">=", "1.11.0"):
64
+ hidden_states = torch.utils.checkpoint.checkpoint(
65
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
66
+ )
67
+ else:
68
+ hidden_states = torch.utils.checkpoint.checkpoint(
69
+ create_custom_forward(resnet), hidden_states, temb
70
+ )
71
+ else:
72
+ hidden_states = resnet(hidden_states, temb)
73
+
74
+ if self.upsamplers is not None:
75
+ for upsampler in self.upsamplers:
76
+ hidden_states = upsampler(hidden_states, upsample_size)
77
+
78
+ return hidden_states
79
+
80
+ return forward
81
+
82
+ for i, upsample_block in enumerate(model.unet.up_blocks):
83
+ if isinstance_str(upsample_block, "UpBlock2D"):
84
+ upsample_block.forward = up_forward(upsample_block)
85
+
86
+
87
+ def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
88
+ def up_forward(self):
89
+ def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
90
+ for resnet in self.resnets:
91
+ # pop res hidden states
92
+ res_hidden_states = res_hidden_states_tuple[-1]
93
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
94
+ #print(f"in free upblock2d, hidden states shape: {hidden_states.shape}")
95
+
96
+ # --------------- FreeU code -----------------------
97
+ # Only operate on the first two stages
98
+ if hidden_states.shape[1] == 1280:
99
+ hidden_states[:,:640] = hidden_states[:,:640] * self.b1
100
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
101
+ if hidden_states.shape[1] == 640:
102
+ hidden_states[:,:320] = hidden_states[:,:320] * self.b2
103
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
104
+ # ---------------------------------------------------------
105
+
106
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
107
+
108
+ if self.training and self.gradient_checkpointing:
109
+
110
+ def create_custom_forward(module):
111
+ def custom_forward(*inputs):
112
+ return module(*inputs)
113
+
114
+ return custom_forward
115
+
116
+ if is_torch_version(">=", "1.11.0"):
117
+ hidden_states = torch.utils.checkpoint.checkpoint(
118
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
119
+ )
120
+ else:
121
+ hidden_states = torch.utils.checkpoint.checkpoint(
122
+ create_custom_forward(resnet), hidden_states, temb
123
+ )
124
+ else:
125
+ hidden_states = resnet(hidden_states, temb)
126
+
127
+ if self.upsamplers is not None:
128
+ for upsampler in self.upsamplers:
129
+ hidden_states = upsampler(hidden_states, upsample_size)
130
+
131
+ return hidden_states
132
+
133
+ return forward
134
+
135
+ for i, upsample_block in enumerate(model.unet.up_blocks):
136
+ if isinstance_str(upsample_block, "UpBlock2D"):
137
+ upsample_block.forward = up_forward(upsample_block)
138
+ setattr(upsample_block, 'b1', b1)
139
+ setattr(upsample_block, 'b2', b2)
140
+ setattr(upsample_block, 's1', s1)
141
+ setattr(upsample_block, 's2', s2)
142
+
143
+
144
+ def register_crossattn_upblock2d(model):
145
+ def up_forward(self):
146
+ def forward(
147
+ hidden_states: torch.FloatTensor,
148
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
149
+ temb: Optional[torch.FloatTensor] = None,
150
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
151
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
152
+ upsample_size: Optional[int] = None,
153
+ attention_mask: Optional[torch.FloatTensor] = None,
154
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
155
+ ):
156
+ for resnet, attn in zip(self.resnets, self.attentions):
157
+ # pop res hidden states
158
+ #print(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}")
159
+ res_hidden_states = res_hidden_states_tuple[-1]
160
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
161
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
162
+
163
+ if self.training and self.gradient_checkpointing:
164
+
165
+ def create_custom_forward(module, return_dict=None):
166
+ def custom_forward(*inputs):
167
+ if return_dict is not None:
168
+ return module(*inputs, return_dict=return_dict)
169
+ else:
170
+ return module(*inputs)
171
+
172
+ return custom_forward
173
+
174
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
175
+ hidden_states = torch.utils.checkpoint.checkpoint(
176
+ create_custom_forward(resnet),
177
+ hidden_states,
178
+ temb,
179
+ **ckpt_kwargs,
180
+ )
181
+ hidden_states = torch.utils.checkpoint.checkpoint(
182
+ create_custom_forward(attn, return_dict=False),
183
+ hidden_states,
184
+ encoder_hidden_states,
185
+ None, # timestep
186
+ None, # class_labels
187
+ cross_attention_kwargs,
188
+ attention_mask,
189
+ encoder_attention_mask,
190
+ **ckpt_kwargs,
191
+ )[0]
192
+ else:
193
+ hidden_states = resnet(hidden_states, temb)
194
+ hidden_states = attn(
195
+ hidden_states,
196
+ encoder_hidden_states=encoder_hidden_states,
197
+ cross_attention_kwargs=cross_attention_kwargs,
198
+ attention_mask=attention_mask,
199
+ encoder_attention_mask=encoder_attention_mask,
200
+ return_dict=False,
201
+ )[0]
202
+
203
+ if self.upsamplers is not None:
204
+ for upsampler in self.upsamplers:
205
+ hidden_states = upsampler(hidden_states, upsample_size)
206
+
207
+ return hidden_states
208
+
209
+ return forward
210
+
211
+ for i, upsample_block in enumerate(model.unet.up_blocks):
212
+ if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
213
+ upsample_block.forward = up_forward(upsample_block)
214
+
215
+
216
+ def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
217
+ def up_forward(self):
218
+ def forward(
219
+ hidden_states: torch.FloatTensor,
220
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
221
+ temb: Optional[torch.FloatTensor] = None,
222
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
223
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
224
+ upsample_size: Optional[int] = None,
225
+ attention_mask: Optional[torch.FloatTensor] = None,
226
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
227
+ ):
228
+ for resnet, attn in zip(self.resnets, self.attentions):
229
+ # pop res hidden states
230
+ #print(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}")
231
+ res_hidden_states = res_hidden_states_tuple[-1]
232
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
233
+
234
+ # --------------- FreeU code -----------------------
235
+ # Only operate on the first two stages
236
+ if hidden_states.shape[1] == 1280:
237
+ hidden_states[:,:640] = hidden_states[:,:640] * self.b1
238
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
239
+ if hidden_states.shape[1] == 640:
240
+ hidden_states[:,:320] = hidden_states[:,:320] * self.b2
241
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
242
+ # ---------------------------------------------------------
243
+
244
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
245
+
246
+ if self.training and self.gradient_checkpointing:
247
+
248
+ def create_custom_forward(module, return_dict=None):
249
+ def custom_forward(*inputs):
250
+ if return_dict is not None:
251
+ return module(*inputs, return_dict=return_dict)
252
+ else:
253
+ return module(*inputs)
254
+
255
+ return custom_forward
256
+
257
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
258
+ hidden_states = torch.utils.checkpoint.checkpoint(
259
+ create_custom_forward(resnet),
260
+ hidden_states,
261
+ temb,
262
+ **ckpt_kwargs,
263
+ )
264
+ hidden_states = torch.utils.checkpoint.checkpoint(
265
+ create_custom_forward(attn, return_dict=False),
266
+ hidden_states,
267
+ encoder_hidden_states,
268
+ None, # timestep
269
+ None, # class_labels
270
+ cross_attention_kwargs,
271
+ attention_mask,
272
+ encoder_attention_mask,
273
+ **ckpt_kwargs,
274
+ )[0]
275
+ else:
276
+ hidden_states = resnet(hidden_states, temb)
277
+ # hidden_states = attn(
278
+ # hidden_states,
279
+ # encoder_hidden_states=encoder_hidden_states,
280
+ # cross_attention_kwargs=cross_attention_kwargs,
281
+ # encoder_attention_mask=encoder_attention_mask,
282
+ # return_dict=False,
283
+ # )[0]
284
+ hidden_states = attn(
285
+ hidden_states,
286
+ encoder_hidden_states=encoder_hidden_states,
287
+ cross_attention_kwargs=cross_attention_kwargs,
288
+ )[0]
289
+
290
+ if self.upsamplers is not None:
291
+ for upsampler in self.upsamplers:
292
+ hidden_states = upsampler(hidden_states, upsample_size)
293
+
294
+ return hidden_states
295
+
296
+ return forward
297
+
298
+ for i, upsample_block in enumerate(model.unet.up_blocks):
299
+ if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
300
+ upsample_block.forward = up_forward(upsample_block)
301
+ setattr(upsample_block, 'b1', b1)
302
+ setattr(upsample_block, 'b2', b2)
303
+ setattr(upsample_block, 's1', s1)
304
+ setattr(upsample_block, 's2', s2)
requirements.txt CHANGED
@@ -1,6 +1,10 @@
- accelerate
- diffusers
- invisible_watermark
- torch
- transformers
- xformers
+ tqdm
+ einops
+ pytorch_lightning
+ accelerate>=0.20.0
+ torchsde
+ pycocotools
+ diffusers==0.32.2
+ timm
+ transformers==4.49
+ torch>=2.0.0
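For a local run outside the Space, the pinned packages above can be installed with pip; gradio and spaces (both used by app.py) appear to be provided by the Space runtime and are not listed, so they would need to be added locally, e.g.:

    pip install -r requirements.txt gradio spaces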
sdxl.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1a8629a4d939f9ed8f02ed2ad39b8317b701fb9e59d175ce186512e4a2687e48
+ size 121965599