Spaces
Commit 1c3f916 (parent: 09c8989): "xl version"
Files changed:
- NoiseTransformer.py +26 -0
- README.md +17 -3
- SVDNoiseUnet.py +430 -0
- app.py +958 -56
- free_lunch_utils.py +304 -0
- requirements.txt +10 -6
- sdxl.pth +3 -0
NoiseTransformer.py
ADDED
@@ -0,0 +1,26 @@
+import torch.nn as nn
+
+from torch.nn import functional as F
+from timm import create_model
+
+
+__all__ = ['NoiseTransformer']
+
+class NoiseTransformer(nn.Module):
+    def __init__(self, resolution=(128,96)):
+        super().__init__()
+        self.upsample = lambda x: F.interpolate(x, [224,224])
+        self.downsample = lambda x: F.interpolate(x, [resolution[0],resolution[1]])
+        self.upconv = nn.Conv2d(7,4,(1,1),(1,1),(0,0))
+        self.downconv = nn.Conv2d(4,3,(1,1),(1,1),(0,0))
+        # self.upconv = nn.Conv2d(7,4,(1,1),(1,1),(0,0))
+        self.swin = create_model("swin_tiny_patch4_window7_224", pretrained=True)
+
+
+    def forward(self, x, residual=False):
+        if residual:
+            x = self.upconv(self.downsample(self.swin.forward_features(self.downconv(self.upsample(x))))) + x
+        else:
+            x = self.upconv(self.downsample(self.swin.forward_features(self.downconv(self.upsample(x)))))
+
+        return x
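For orientation, a minimal shape-trace of this module; this sketch is mine, not part of the commit. It assumes a recent timm in which `forward_features` on this Swin model returns a 4-D tensor whose second dimension is 7, which is what `upconv`'s `in_channels=7` relies on; the input size matches the default `resolution=(128,96)`:

```python
# Sketch, not part of the commit: trace the resize -> conv -> Swin -> resize -> conv path.
import torch
from NoiseTransformer import NoiseTransformer

model = NoiseTransformer(resolution=(128, 96)).eval()
latent = torch.randn(1, 4, 128, 96)      # a 4-channel latent (e.g. a 1024x768 image / 8)
with torch.no_grad():
    out = model(latent, residual=True)   # upsample -> downconv -> swin -> downsample -> upconv, plus skip
print(out.shape)                         # expected: torch.Size([1, 4, 128, 96])
```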
README.md
CHANGED
@@ -1,5 +1,5 @@
 ---
-title: Hyperparameters
+title: Hyperparameters-are-all-you-need-xl-version-improved-implementation
 emoji: 🖼
 colorFrom: purple
 colorTo: red
@@ -8,7 +8,21 @@ sdk_version: 5.44.0
 app_file: app.py
 pinned: false
 license: apache-2.0
-short_description: training-free
+short_description: training-free few step diffusion ODE solver
+header: default
 ---
 
-
+**Abstract:** The diffusion model is a state-of-the-art generative model that generates an image by applying a neural network iteratively; this generation process can be regarded as solving an ordinary differential equation (ODE) or a stochastic differential equation (SDE). Based on an analysis of the truncation error of the diffusion ODE and SDE, our study proposes a training-free algorithm that generates high-quality 512 x 512 and 1024 x 1024 images in eight steps, with flexible guidance scales. To the best of our knowledge, it is the first algorithm that samples a 1024 x 1024 image in 8 steps with FID performance comparable to that of the latest distillation models, but without additional training. Meanwhile, our algorithm can also generate a 512 x 512 image in 8 steps with FID performance better than that of the state-of-the-art ODE solver DPM++ 2M at 20 steps. We validate our eight-step image generation algorithm on the COCO 2014, COCO 2017, and LAION datasets: our best FID scores are 15.7, 22.35, and 17.52, while DPM++ 2M reaches 17.3, 23.75, and 17.33, and the state-of-the-art AMED-plugin solver reaches 19.07, 25.50, and 18.06. We also apply the algorithm to five-step inference without additional training, where the best FID scores on the datasets above are 19.18, 23.24, and 19.61, respectively, comparable to the state-of-the-art AMED-plugin solver in eight steps, SDXL-Turbo in four steps, and the state-of-the-art diffusion distillation model Flash Diffusion in five steps. We further validate our algorithm on synthesizing 1024 x 1024 images within 6 steps, where its FID performance falls only slightly behind the latest distillation algorithms.
+
+This demo is a simplified version of the approach described in the paper ["Hyperparameters are all you need: Using five-step inference for an original diffusion model to generate images comparable to the latest distillation model"](https://arxiv.org/abs/2510.02390).
+
+```
+@misc{hyper,
+      title={Hyperparameters are all you need: Using five-step inference for an original diffusion model to generate images comparable to the latest distillation model},
+      author={Zilai Li},
+      year={2025},
+      eprint={2510.02390},
+      archivePrefix={arXiv},
+      primaryClass={eess.IV}
+}
+```
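For context on the abstract's truncation-error claim: the samplers in this Space integrate the probability-flow ODE, and each discretization step incurs a local error that the paper's hyperparameter choices work around. A minimal sketch in the sigma-parameterization of Karras et al. (2022), which matches the `to_d(x, sigma, denoised) = (x - denoised) / sigma` helper in `app.py` below; the notation here is a standard assumption, not taken from the paper itself:

```latex
\frac{\mathrm{d}x}{\mathrm{d}\sigma} = \frac{x - D(x;\sigma)}{\sigma},
\qquad
x_{i+1} = x_i + (\sigma_{i+1}-\sigma_i)\,\frac{x_i - D(x_i;\sigma_i)}{\sigma_i}
          + O\!\left((\sigma_{i+1}-\sigma_i)^2\right)
```

where \(D(x;\sigma)\) is the denoiser. The \(O(h^2)\) local truncation error of the Euler step is what higher-order solvers (Heun, DPM++ 2M) and few-step sigma schedules trade off against each other.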
SVDNoiseUnet.py
ADDED
@@ -0,0 +1,430 @@
+import torch
+import torch.nn as nn
+import einops
+
+from torch.nn import functional as F
+from torch.jit import Final
+from timm.layers import use_fused_attn
+from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, get_act_layer
+from abc import abstractmethod
+from NoiseTransformer import NoiseTransformer
+from einops import rearrange
+__all__ = ['SVDNoiseUnet', 'SVDNoiseUnet_Concise']
+
+class Attention(nn.Module):
+    fused_attn: Final[bool]
+
+    def __init__(
+            self,
+            dim: int,
+            num_heads: int = 8,
+            qkv_bias: bool = False,
+            qk_norm: bool = False,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+            norm_layer: nn.Module = nn.LayerNorm,
+    ) -> None:
+        super().__init__()
+        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim ** -0.5
+        self.fused_attn = use_fused_attn()
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
+        q, k = self.q_norm(q), self.k_norm(k)
+
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(
+                q, k, v,
+                dropout_p=self.attn_drop.p if self.training else 0.,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v
+
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
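The Attention block above follows the standard timm ViT attention; when `use_fused_attn()` is true it defers to PyTorch's SDPA kernel, which computes the same softmax attention as the manual branch. A quick equivalence sketch (mine, not the commit's):

```python
# Sketch, not part of the commit: the fused and manual attention branches agree.
import torch
import torch.nn.functional as F

q, k, v = torch.randn(3, 1, 8, 16, 64).unbind(0)   # (batch, heads, tokens, head_dim)
fused = F.scaled_dot_product_attention(q, k, v)
manual = ((q * 64 ** -0.5) @ k.transpose(-2, -1)).softmax(dim=-1) @ v
print(torch.allclose(fused, manual, atol=1e-5))     # True
```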
+
+class SVDNoiseUnet(nn.Module):
+    def __init__(self, in_channels=4, out_channels=4, resolution=(128,96)): # resolution = size // 8
+        super(SVDNoiseUnet, self).__init__()
+
+        _in_1 = int(resolution[0] * in_channels // 2)
+        _out_1 = int(resolution[0] * out_channels // 2)
+
+        _in_2 = int(resolution[1] * in_channels // 2)
+        _out_2 = int(resolution[1] * out_channels // 2)
+        self.mlp1 = nn.Sequential(
+            nn.Linear(_in_1, 64),
+            nn.ReLU(inplace=True),
+            nn.Linear(64, _out_1),
+        )
+        self.mlp2 = nn.Sequential(
+            nn.Linear(_in_2, 64),
+            nn.ReLU(inplace=True),
+            nn.Linear(64, _out_2),
+        )
+
+        self.mlp3 = nn.Sequential(
+            nn.Linear(_in_2, _out_2),
+        )
+
+        self.attention = Attention(_out_2)
+
+        self.bn = nn.BatchNorm1d(256)
+        self.bn2 = nn.BatchNorm1d(192)
+
+        self.mlp4 = nn.Sequential(
+            nn.Linear(_out_2, 1024),
+            nn.ReLU(inplace=True),
+            nn.Linear(1024, _out_2),
+        )
+        self.ffn = nn.Sequential(
+            nn.Linear(256, 384),  # Expand
+            nn.ReLU(inplace=True),
+            nn.Linear(384, 192)   # Reduce to target size
+        )
+        self.ffn2 = nn.Sequential(
+            nn.Linear(256, 384),  # Expand
+            nn.ReLU(inplace=True),
+            nn.Linear(384, 192)   # Reduce to target size
+        )
+        # self.adaptive_pool = nn.AdaptiveAvgPool2d((256, 192))
+
+    def forward(self, x, residual=False):
+        b, c, h, w = x.shape
+        x = einops.rearrange(x, "b (a c) h w -> b (a h) (c w)", a=2, c=2) # x -> [1, 256, 256]
+        U, s, V = torch.linalg.svd(x) # U->[b 256 256], s-> [b 256], V->[b 256 256]
+        U_T = U.permute(0, 2, 1)
+        U_out = self.ffn(self.mlp1(U_T))
+        U_out = self.bn(U_out)
+        U_out = U_out.transpose(1, 2)
+        U_out = self.ffn2(U_out) # [b, 256, 256] -> [b, 256, 192]
+        U_out = self.bn2(U_out)
+        U_out = U_out.transpose(1, 2)
+        # U_out = self.bn(U_out)
+        V_out = self.mlp2(V)
+        s_out = self.mlp3(s).unsqueeze(1) # s -> [b, 1, 256] => [b, 256, 256]
+        out = U_out + V_out + s_out
+        # print(out.size())
+        out = out.squeeze(1)
+        out = self.attention(out).mean(1)
+        out = self.mlp4(out) + s
+        diagonal_out = torch.diag_embed(out)
+        padded_diag = F.pad(diagonal_out, (0, 0, 0, 64), mode='constant', value=0) # Shape: [b, 1, 256, 192]
+        pred = U @ padded_diag @ V
+        return einops.rearrange(pred, "b (a h) (c w) -> b (a c) h w", a=2, c=2)
+
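The forward passes above and below lean on the SVD reconstruction identity: the latent is folded into a square matrix, decomposed, and reassembled with network-predicted singular values. A tiny sanity-check sketch (mine, not the commit's):

```python
# Sketch, not part of the commit: torch.linalg.svd returns V already transposed
# (i.e. "Vh"), which is why the code can multiply U @ diag_embed(s) @ V directly.
import torch

x = torch.randn(2, 256, 256)
U, s, Vh = torch.linalg.svd(x)
recon = U @ torch.diag_embed(s) @ Vh
print(torch.allclose(recon, x, atol=1e-3))   # True: near-exact reconstruction in float32
```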
+class SVDNoiseUnet64(nn.Module):
+    def __init__(self, in_channels=4, out_channels=4, resolution=64): # resolution = size // 8
+        super(SVDNoiseUnet64, self).__init__()
+
+        _in = int(resolution * in_channels // 2)
+        _out = int(resolution * out_channels // 2)
+        self.mlp1 = nn.Sequential(
+            nn.Linear(_in, 64),
+            nn.ReLU(inplace=True),
+            nn.Linear(64, _out),
+        )
+        self.mlp2 = nn.Sequential(
+            nn.Linear(_in, 64),
+            nn.ReLU(inplace=True),
+            nn.Linear(64, _out),
+        )
+
+        self.mlp3 = nn.Sequential(
+            nn.Linear(_in, _out),
+        )
+
+        self.attention = Attention(_out)
+
+        self.bn = nn.BatchNorm2d(_out)
+
+        self.mlp4 = nn.Sequential(
+            nn.Linear(_out, 1024),
+            nn.ReLU(inplace=True),
+            nn.Linear(1024, _out),
+        )
+
+    def forward(self, x, residual=False):
+        b, c, h, w = x.shape
+        x = einops.rearrange(x, "b (a c) h w -> b (a h) (c w)", a=2, c=2) # x -> [1, 256, 256]
+        U, s, V = torch.linalg.svd(x) # U->[b 256 256], s-> [b 256], V->[b 256 256]
+        U_T = U.permute(0, 2, 1)
+        out = self.mlp1(U_T) + self.mlp2(V) + self.mlp3(s).unsqueeze(1) # s -> [b, 1, 256] => [b, 256, 256]
+        out = self.attention(out).mean(1)
+        out = self.mlp4(out) + s
+        pred = U @ torch.diag_embed(out) @ V
+        return einops.rearrange(pred, "b (a h) (c w) -> b (a c) h w", a=2, c=2)
+
+
+
+class SVDNoiseUnet128(nn.Module):
+    def __init__(self, in_channels=4, out_channels=4, resolution=128): # resolution = size // 8
+        super(SVDNoiseUnet128, self).__init__()
+
+        _in = int(resolution * in_channels // 2)
+        _out = int(resolution * out_channels // 2)
+        self.mlp1 = nn.Sequential(
+            nn.Linear(_in, 64),
+            nn.ReLU(inplace=True),
+            nn.Linear(64, _out),
+        )
+        self.mlp2 = nn.Sequential(
+            nn.Linear(_in, 64),
+            nn.ReLU(inplace=True),
+            nn.Linear(64, _out),
+        )
+
+        self.mlp3 = nn.Sequential(
+            nn.Linear(_in, _out),
+        )
+
+        self.attention = Attention(_out)
+
+        self.bn = nn.BatchNorm2d(_out)
+
+        self.mlp4 = nn.Sequential(
+            nn.Linear(_out, 1024),
+            nn.ReLU(inplace=True),
+            nn.Linear(1024, _out),
+        )
+
+    def forward(self, x, residual=False):
+        b, c, h, w = x.shape
+        x = einops.rearrange(x, "b (a c) h w -> b (a h) (c w)", a=2, c=2)
+        U, s, V = torch.linalg.svd(x)
+        U_T = U.permute(0, 2, 1)
+        out = self.mlp1(U_T) + self.mlp2(V) + self.mlp3(s).unsqueeze(1)
+        out = self.attention(out).mean(1)
+        out = self.mlp4(out) + s
+        pred = U @ torch.diag_embed(out) @ V
+        return einops.rearrange(pred, "b (a h) (c w) -> b (a c) h w", a=2, c=2)
+
+
+
+class SVDNoiseUnet_Concise(nn.Module):
+    def __init__(self, in_channels=4, out_channels=4, resolution=64):
+        super(SVDNoiseUnet_Concise, self).__init__()
+
+
+from diffusers.models.normalization import AdaGroupNorm
+
+class NPNet(nn.Module):
+    def __init__(self, model_id, pretrained_path=' ', device='cuda') -> None:
+        super(NPNet, self).__init__()
+
+        assert model_id in ['SD1.5', 'DreamShaper', 'DiT']
+
+        self.model_id = model_id
+        self.device = device
+        self.pretrained_path = pretrained_path
+
+        (
+            self.unet_svd,
+            self.unet_embedding,
+            self.text_embedding,
+            self._alpha,
+            self._beta
+        ) = self.get_model()
+
+    def save_model(self, save_path: str):
+        """
+        Save this NPNet so that get_model() can later reload it.
+        """
+        torch.save({
+            "unet_svd": self.unet_svd.state_dict(),
+            "unet_embedding": self.unet_embedding.state_dict(),
+            "embeeding": self.text_embedding.state_dict(), # matches get_model's key
+            "alpha": self._alpha,
+            "beta": self._beta,
+        }, save_path)
+        print(f"NPNet saved to {save_path}")
+
+    def get_model(self):
+
+        unet_embedding = NoiseTransformer(resolution=(128,96)).to(self.device).to(torch.float32)
+        unet_svd = SVDNoiseUnet(resolution=(128,96)).to(self.device).to(torch.float32)
+
+        if self.model_id == 'DiT':
+            text_embedding = AdaGroupNorm(768 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
+        else:
+            text_embedding = AdaGroupNorm(768 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
+
+        # initialize random _alpha and _beta when no checkpoint is provided
+        _alpha = torch.randn(1, device=self.device)
+        _beta = torch.randn(1, device=self.device)
+
+        if '.pth' in self.pretrained_path:
+            golden_unet = torch.load(self.pretrained_path)
+            unet_svd.load_state_dict(golden_unet["unet_svd"], strict=True)
+            unet_embedding.load_state_dict(golden_unet["unet_embedding"], strict=True)
+            text_embedding.load_state_dict(golden_unet["embeeding"], strict=True)
+            _alpha = golden_unet["alpha"]
+            _beta = golden_unet["beta"]
+
+            print("Load Successfully!")
+
+            return unet_svd, unet_embedding, text_embedding, _alpha, _beta
+
+        else:
+            return unet_svd, unet_embedding, text_embedding, _alpha, _beta
+
+    def forward(self, initial_noise, prompt_embeds):
+
+        prompt_embeds = prompt_embeds.float().view(prompt_embeds.shape[0], -1)
+        text_emb = self.text_embedding(initial_noise.float(), prompt_embeds)
+
+        encoder_hidden_states_svd = initial_noise
+        encoder_hidden_states_embedding = initial_noise + text_emb
+
+        golden_embedding = self.unet_embedding(encoder_hidden_states_embedding.float())
+
+        golden_noise = self.unet_svd(encoder_hidden_states_svd.float()) + (
+                2 * torch.sigmoid(self._alpha) - 1) * text_emb + self._beta * golden_embedding
+
+        return golden_noise
+
+
+class NPNet64(nn.Module):
+    def __init__(self, model_id, pretrained_path=' ', device='cuda') -> None:
+        super(NPNet64, self).__init__()
+        self.model_id = model_id
+        self.device = device
+        self.pretrained_path = pretrained_path
+
+        (
+            self.unet_svd,
+            self.unet_embedding,
+            self.text_embedding,
+            self._alpha,
+            self._beta
+        ) = self.get_model()
+
+    def save_model(self, save_path: str):
+        """
+        Save this NPNet so that get_model() can later reload it.
+        """
+        torch.save({
+            "unet_svd": self.unet_svd.state_dict(),
+            "unet_embedding": self.unet_embedding.state_dict(),
+            "embeeding": self.text_embedding.state_dict(), # matches get_model's key
+            "alpha": self._alpha,
+            "beta": self._beta,
+        }, save_path)
+        print(f"NPNet saved to {save_path}")
+
+    def get_model(self):
+
+        unet_embedding = NoiseTransformer(resolution=(64,64)).to(self.device).to(torch.float32)
+        unet_svd = SVDNoiseUnet64(resolution=64).to(self.device).to(torch.float32)
+        _alpha = torch.randn(1, device=self.device)
+        _beta = torch.randn(1, device=self.device)
+
+        text_embedding = AdaGroupNorm(768 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
+
+        if '.pth' in self.pretrained_path:
+            golden_unet = torch.load(self.pretrained_path)
+            unet_svd.load_state_dict(golden_unet["unet_svd"])
+            unet_embedding.load_state_dict(golden_unet["unet_embedding"])
+            text_embedding.load_state_dict(golden_unet["embeeding"])
+            _alpha = golden_unet["alpha"]
+            _beta = golden_unet["beta"]
+
+            print("Load Successfully!")
+
+        return unet_svd, unet_embedding, text_embedding, _alpha, _beta
+
+    def forward(self, initial_noise, prompt_embeds):
+
+        prompt_embeds = prompt_embeds.float().view(prompt_embeds.shape[0], -1)
+        text_emb = self.text_embedding(initial_noise.float(), prompt_embeds)
+
+        encoder_hidden_states_svd = initial_noise
+        encoder_hidden_states_embedding = initial_noise + text_emb
+
+        golden_embedding = self.unet_embedding(encoder_hidden_states_embedding.float())
+
+        golden_noise = self.unet_svd(encoder_hidden_states_svd.float()) + (
+                2 * torch.sigmoid(self._alpha) - 1) * text_emb + self._beta * golden_embedding
+
+        return golden_noise
+
+class NPNet128(nn.Module):
+    def __init__(self, model_id, pretrained_path=' ', device='cuda') -> None:
+        super(NPNet128, self).__init__()
+
+        assert model_id in ['SDXL', 'DreamShaper', 'DiT']
+
+        self.model_id = model_id
+        self.device = device
+        self.pretrained_path = pretrained_path
+
+        (
+            self.unet_svd,
+            self.unet_embedding,
+            self.text_embedding,
+            self._alpha,
+            self._beta
+        ) = self.get_model()
+
+    def get_model(self):
+
+        unet_embedding = NoiseTransformer(resolution=(128,128)).to(self.device).to(torch.float32)
+        unet_svd = SVDNoiseUnet128(resolution=128).to(self.device).to(torch.float32)
+
+        if self.model_id == 'DiT':
+            text_embedding = AdaGroupNorm(1024 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
+        else:
+            text_embedding = AdaGroupNorm(2048 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
+
+        if '.pth' in self.pretrained_path:
+            golden_unet = torch.load(self.pretrained_path)
+            unet_svd.load_state_dict(golden_unet["unet_svd"])
+            unet_embedding.load_state_dict(golden_unet["unet_embedding"])
+            text_embedding.load_state_dict(golden_unet["embeeding"])
+            _alpha = golden_unet["alpha"]
+            _beta = golden_unet["beta"]
+
+            print("Load Successfully!")
+
+            return unet_svd, unet_embedding, text_embedding, _alpha, _beta
+
+        else:
+            raise FileNotFoundError("No Pretrained Weights Found!")
+
+    def forward(self, initial_noise, prompt_embeds):
+
+        prompt_embeds = prompt_embeds.float().view(prompt_embeds.shape[0], -1)
+        text_emb = self.text_embedding(initial_noise.float(), prompt_embeds)
+
+        encoder_hidden_states_svd = initial_noise
+        encoder_hidden_states_embedding = initial_noise + text_emb
+
+        golden_embedding = self.unet_embedding(encoder_hidden_states_embedding.float())
+
+        golden_noise = self.unet_svd(encoder_hidden_states_svd.float()) + (
+                2 * torch.sigmoid(self._alpha) - 1) * text_emb + self._beta * golden_embedding
+
+        return golden_noise
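A hedged usage sketch for the classes above (my reading of the code, not an example shipped with the commit): `NPNet128` loads the SVD and Swin branches plus the learned `alpha`/`beta` mixing scalars from a checkpoint such as the `sdxl.pth` added in this commit, then turns an initial latent into a "golden" one. The prompt-embedding shape below is an assumption for SDXL (`2048 * 77` matches the `AdaGroupNorm` input above); in practice it comes from `pipe.encode_prompt`:

```python
# Sketch, not part of the commit: calling NPNet128 on a random SDXL-sized latent.
import torch
from SVDNoiseUnet import NPNet128

npnet = NPNet128('SDXL', pretrained_path='sdxl.pth', device='cuda')
initial_noise = torch.randn(1, 4, 128, 128, device='cuda')   # latent for a 1024x1024 image
prompt_embeds = torch.randn(1, 77, 2048, device='cuda')      # stand-in for pipe.encode_prompt output
golden_noise = npnet(initial_noise, prompt_embeds)           # same shape as initial_noise
```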
app.py
CHANGED
@@ -1,60 +1,966 @@
 import gradio as gr
 import numpy as np
 import random
-
-
-from diffusers import
 import torch
-
 device = "cuda" if torch.cuda.is_available() else "cpu"
-model_repo_id = "
 
-
-
-
-
 
-pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-pipe = pipe.to(device)
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 
 
-
 def infer(
     prompt,
     negative_prompt,
     seed,
     randomize_seed,
-
-    height,
     guidance_scale,
     num_inference_steps,
     progress=gr.Progress(track_tqdm=True),
 ):
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
 
-
-
-    image = pipe(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        width=width,
-        height=height,
-        generator=generator,
-    ).images[0]
-
-    return image, seed
 
 
 examples = [
     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "
-    "
 ]
 
 css = """
@@ -66,7 +972,7 @@ css = """
 
 with gr.Blocks(css=css) as demo:
     with gr.Column(elem_id="col-container"):
-        gr.Markdown(" #
 
         with gr.Row():
             prompt = gr.Text(
@@ -79,7 +985,13 @@ with gr.Blocks(css=css) as demo:
 
             run_button = gr.Button("Run", scale=0, variant="primary")
 
-
 
         with gr.Accordion("Advanced Settings", open=False):
             negative_prompt = gr.Text(
@@ -99,22 +1011,15 @@ with gr.Blocks(css=css) as demo:
 
             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
 
-
-
-
-
-
-
-
-
-
-            height = gr.Slider(
-                label="Height",
-                minimum=256,
-                maximum=MAX_IMAGE_SIZE,
-                step=32,
-                value=1024, # Replace with defaults that work for your model
-            )
 
             with gr.Row():
                 guidance_scale = gr.Slider(
@@ -122,15 +1027,13 @@ with gr.Blocks(css=css) as demo:
                     minimum=0.0,
                     maximum=10.0,
                     step=0.1,
-                    value=
                 )
 
-                num_inference_steps = gr.
                     label="Number of inference steps",
-                    minimum=1,
-                    maximum=50,
-                    step=1,
-                    value=2, # Replace with defaults that work for your model
                 )
 
     gr.Examples(examples=examples, inputs=[prompt])
@@ -142,12 +1045,11 @@ with gr.Blocks(css=css) as demo:
             negative_prompt,
            seed,
             randomize_seed,
-
-            height,
             guidance_scale,
             num_inference_steps,
         ],
-        outputs=[result, seed],
     )
 
 if __name__ == "__main__":

The diff page then lists the new app.py in full; the capture is truncated partway through.

 import gradio as gr
 import numpy as np
 import random
+import json
+import spaces #[uncomment to use ZeroGPU]
+from diffusers import (
+    AutoencoderKL,
+    StableDiffusionXLPipeline,
+)
+from huggingface_hub import login, hf_hub_download
+from PIL import Image
+# from huggingface_hub import login
+from SVDNoiseUnet import NPNet64
+import functools
+import random
+from free_lunch_utils import register_free_upblock2d, register_free_crossattn_upblock2d
 import torch
+import torch.nn as nn
+from einops import rearrange
+from torchvision.utils import make_grid
+import time
+from pytorch_lightning import seed_everything
+from torch import autocast
+from contextlib import contextmanager, nullcontext
+import accelerate
+import torchsde
+from SVDNoiseUnet import NPNet128
+from tqdm import tqdm, trange
 device = "cuda" if torch.cuda.is_available() else "cpu"
+model_repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5" # Replace to the model you would like to use
 
+precision_scope = autocast
+
+def extract_into_tensor(a, t, x_shape):
+    b, *_ = t.shape
+    out = a.gather(-1, t)
+    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
 
 
+def append_zero(x):
+    return torch.cat([x, x.new_zeros([1])])
+
+# New helper to load a list-of-dicts preference JSON
+# JSON schema: [ { 'human_preference': [int], 'prompt': str, 'file_path': [str] }, ... ]
+def load_preference_json(json_path: str) -> list[dict]:
+    """Load records from a JSON file formatted as a list of preference dicts."""
+    with open(json_path, 'r') as f:
+        data = json.load(f)
+    return data
+
+# New helper to extract just the prompts from the preference JSON
+# Returns a flat list of all 'prompt' values
+
+def extract_prompts_from_pref_json(json_path: str) -> list[str]:
+    """Load a JSON of preference records and return only the prompts."""
+    records = load_preference_json(json_path)
+    return [rec['prompt'] for rec in records]
+
+# Example usage:
+# prompts = extract_prompts_from_pref_json("path/to/preference.json")
+# print(prompts)
+
+def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu', need_append_zero=True):
+    """Constructs the noise schedule of Karras et al. (2022)."""
+    ramp = torch.linspace(0, 1, n)
+    min_inv_rho = sigma_min ** (1 / rho)
+    max_inv_rho = sigma_max ** (1 / rho)
+    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+    return append_zero(sigmas).to(device) if need_append_zero else sigmas.to(device)
+
+def extract_into_tensor(a, t, x_shape):
+    b, *_ = t.shape
+    out = a.gather(-1, t)
+    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+def append_zero(x):
+    return torch.cat([x, x.new_zeros([1])])
+
+def append_dims(x, target_dims):
+    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+    dims_to_append = target_dims - x.ndim
+    if dims_to_append < 0:
+        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+    return x[(...,) + (None,) * dims_to_append]
+
+class CFGDenoiser(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.inner_model = model
+
+
+    def prepare_sdxl_pipeline_step_parameter(self, pipe, prompts, need_cfg, device):
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = pipe.encode_prompt(
+            prompt=prompts,
+            device=device,
+            do_classifier_free_guidance=need_cfg,
+        )
+        # timesteps = pipe.scheduler.timesteps
+
+        prompt_embeds = prompt_embeds.to(device)
+        add_text_embeds = pooled_prompt_embeds.to(device)
+        original_size = (1024, 1024)
+        crops_coords_top_left = (0, 0)
+        target_size = (1024, 1024)
+        text_encoder_projection_dim = None
+        add_time_ids = list(original_size + crops_coords_top_left + target_size)
+        if pipe.text_encoder_2 is None:
+            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+        else:
+            text_encoder_projection_dim = pipe.text_encoder_2.config.projection_dim
+        passed_add_embed_dim = (
+            pipe.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+        )
+        expected_add_embed_dim = pipe.unet.add_embedding.linear_1.in_features
+        if expected_add_embed_dim != passed_add_embed_dim:
+            raise ValueError(
+                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+            )
+        add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
+        add_time_ids = add_time_ids.to(device)
+        negative_add_time_ids = add_time_ids
+
+        if need_cfg:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+        ret_dict = {
+            "text_embeds": add_text_embeds,
+            "time_ids": add_time_ids
+        }
+        return prompt_embeds, ret_dict
+
+
+    def get_golden_noised(self, x, sigma, sigma_nxt, prompt, cond_scale, need_distill_uncond=False, tmp_list=[], uncond_list=[], noise_training_list={}):
+        x_in = torch.cat([x] * 2)
+        sigma_in = torch.cat([sigma] * 2)
+        sigma_nxt = torch.cat([sigma_nxt] * 2)
+        prompt_embeds, cond_kwargs = self.prepare_sdxl_pipeline_step_parameter(self.inner_model.pipe, prompt, need_cfg=True, device=self.inner_model.pipe.device)
+        _, ret = self.inner_model.get_customed_golden_noise(x_in
+                                                            , cond_scale
+                                                            , sigma_in, sigma_nxt
+                                                            , True
+                                                            , noise_training_list=noise_training_list
+                                                            , encoder_hidden_states=prompt_embeds.to(device=x.device, dtype=x.dtype)
+                                                            , added_cond_kwargs=cond_kwargs).chunk(2)
+
+        return ret
+
+    def forward(self, x, sigma, prompt, cond_scale, need_distill_uncond=False, tmp_list=[], uncond_list=[]):
+        prompt_embeds, cond_kwargs = self.prepare_sdxl_pipeline_step_parameter(self.inner_model.pipe, prompt, need_cfg=True, device=self.inner_model.pipe.device)
+        # w = cond_scale * x.new_ones([x.shape[0]])
+        # w_embedding = guidance_scale_embedding(w, embedding_dim=self.inner_model.inner_model.config["time_cond_proj_dim"])
+        # w_embedding = w_embedding.to(device=x.device, dtype=x.dtype)
+        # # t = self.inner_model.sigma_to_t(sigma)
+        # cond = self.inner_model(
+        #     x,
+        #     sigma,
+        #     timestep_cond=w_embedding,
+        #     encoder_hidden_states=cond.to(device=x.device, dtype=x.dtype),
+        # )
+        # return cond
+        x_in = torch.cat([x] * 2)
+        sigma_in = torch.cat([sigma] * 2)
+        # cond_in = torch.cat([uncond, cond])
+        uncond, cond = self.inner_model(x_in
+                                        , sigma_in
+                                        , tmp_list
+                                        , encoder_hidden_states=prompt_embeds.to(device=x.device, dtype=x.dtype)
+                                        , added_cond_kwargs=cond_kwargs).chunk(2)
+        if need_distill_uncond:
+            uncond_list.append(uncond)
+        return prompt_embeds, uncond + (cond - uncond) * cond_scale
+
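An aside between `CFGDenoiser` and the schedule classes that follow (my summary, not part of the commit): `CFGDenoiser.forward` returns the standard classifier-free-guidance combination, and `get_sigmas_karras` above produces the descending noise levels the samplers below iterate over. A self-contained sketch of both formulas:

```python
# Sketch, not part of the commit: the two core sampler-side formulas in this file.
import torch

# 1) Classifier-free guidance, as in CFGDenoiser.forward:
#    guided = uncond + scale * (cond - uncond)
uncond, cond = torch.randn(2, 4, 16, 16).chunk(2)
guided = uncond + 7.5 * (cond - uncond)

# 2) Karras schedule, as in get_sigmas_karras: 8 descending sigmas plus a final 0.
#    sigma_min/sigma_max here are typical Stable Diffusion values, an assumption,
#    not values read from this diff.
n, rho, sigma_min, sigma_max = 8, 7.0, 0.0292, 14.6146
ramp = torch.linspace(0, 1, n)
sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
print(guided.shape, sigmas)
```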
| 179 |
+
|
| 180 |
+
class DiscreteSchedule(nn.Module):
|
| 181 |
+
"""A mapping between continuous noise levels (sigmas) and a list of discrete noise
|
| 182 |
+
levels."""
|
| 183 |
+
|
| 184 |
+
def __init__(self, sigmas, quantize):
|
| 185 |
+
super().__init__()
|
| 186 |
+
self.register_buffer('sigmas', sigmas)
|
| 187 |
+
self.register_buffer('log_sigmas', sigmas.log())
|
| 188 |
+
self.quantize = quantize
|
| 189 |
+
|
| 190 |
+
@property
|
| 191 |
+
def sigma_min(self):
|
| 192 |
+
return self.sigmas[0]
|
| 193 |
+
|
| 194 |
+
@property
|
| 195 |
+
def sigma_max(self):
|
| 196 |
+
return self.sigmas[-1]
|
| 197 |
+
|
| 198 |
+
def get_sigmas(self, n=None):
|
| 199 |
+
if n is None:
|
| 200 |
+
return append_zero(self.sigmas.flip(0))
|
| 201 |
+
t_max = len(self.sigmas) - 1
|
| 202 |
+
t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
|
| 203 |
+
return append_zero(self.t_to_sigma(t))
|
| 204 |
+
|
| 205 |
+
def sigma_to_t(self, sigma, quantize=None):
|
| 206 |
+
quantize = self.quantize if quantize is None else quantize
|
| 207 |
+
log_sigma = sigma.log()
|
| 208 |
+
dists = log_sigma - self.log_sigmas[:, None]
|
| 209 |
+
if quantize:
|
| 210 |
+
return dists.abs().argmin(dim=0).view(sigma.shape)
|
| 211 |
+
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
|
| 212 |
+
high_idx = low_idx + 1
|
| 213 |
+
low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
|
| 214 |
+
w = (low - log_sigma) / (low - high)
|
| 215 |
+
w = w.clamp(0, 1)
|
| 216 |
+
t = (1 - w) * low_idx + w * high_idx
|
| 217 |
+
return t.view(sigma.shape)
|
| 218 |
+
|
| 219 |
+
def t_to_sigma(self, t):
|
| 220 |
+
t = t.float()
|
| 221 |
+
low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
|
| 222 |
+
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
| 223 |
+
return log_sigma.exp()
|
| 224 |
+
|
| 225 |
+
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
|
| 226 |
+
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
|
| 227 |
+
noise)."""
|
| 228 |
+
|
| 229 |
+
def __init__(self, pipe, alphas_cumprod, quantize = False):
|
| 230 |
+
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
|
| 231 |
+
self.pipe = pipe
|
| 232 |
+
self.inner_model = pipe.unet
|
| 233 |
+
# self.alphas_cumprod = alphas_cumprod.flip(0)
|
| 234 |
+
# Prepare a reversed version of alphas_cumprod for backward scheduling
|
| 235 |
+
self.register_buffer('alphas_cumprod', alphas_cumprod)
|
| 236 |
+
# self.register_buffer('alphas_cumprod_prev', append_zero(alphas_cumprod[:-1]))
|
| 237 |
+
self.sigma_data = 1.
|
| 238 |
+
|
| 239 |
+
def get_scalings(self, sigma):
|
| 240 |
+
c_out = -sigma
|
| 241 |
+
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
| 242 |
+
return c_out, c_in
|
| 243 |
+
|
| 244 |
+
def get_eps(self, *args, **kwargs):
|
| 245 |
+
return self.inner_model(*args, **kwargs)
|
| 246 |
+
|
| 247 |
+
def get_alphact_and_sigma(self, timesteps, x_0, noise):
|
| 248 |
+
high_idx = torch.ceil(timesteps).int()
|
| 249 |
+
low_idx = torch.floor(timesteps).int()
|
| 250 |
+
|
| 251 |
+
nxt_ts = timesteps - timesteps.new_ones(timesteps.shape[0])
|
| 252 |
+
|
| 253 |
+
w = (timesteps - low_idx) / (high_idx - low_idx)
|
| 254 |
+
|
| 255 |
+
beta_1 = torch.tensor([1e-4],dtype=torch.float32)
|
| 256 |
+
beta_T = torch.tensor([0.02],dtype=torch.float32)
|
| 257 |
+
ddpm_max_step = torch.tensor([1000.0],dtype=torch.float32)
|
| 258 |
+
|
| 259 |
+
beta_t: torch.Tensor = (beta_T - beta_1) / ddpm_max_step * timesteps + beta_1
|
| 260 |
+
beta_t_prev: torch.Tensor = (beta_T - beta_1) / ddpm_max_step * nxt_ts + beta_1
|
| 261 |
+
|
| 262 |
+
alpha_t = beta_t.new_ones(beta_t.shape[0]) - beta_t
|
| 263 |
+
alpha_t_prev = beta_t.new_ones(beta_t.shape[0]) - beta_t_prev
|
| 264 |
+
|
| 265 |
+
dir_xt = (1. - alpha_t_prev).sqrt() * noise
|
| 266 |
+
x_prev = alpha_t_prev.sqrt() * x_0 + dir_xt + noise
|
| 267 |
+
|
| 268 |
+
alpha_cumprod_t_floor = self.alpha_cumprods[low_idx]
|
| 269 |
+
alpha_cumprod_t = (alpha_cumprod_t_floor * alpha_t) #.unsqueeze(1)
|
| 270 |
+
sqrt_alpha_cumprod_t = torch.sqrt(alpha_cumprod_t)
|
| 271 |
+
sigmas = torch.sqrt(alpha_cumprod_t.new_ones(alpha_cumprod_t.shape[0]) - alpha_cumprod_t)
|
| 272 |
+
|
| 273 |
+
# Fix broadcasting
|
| 274 |
+
sqrt_alpha_cumprod_t = sqrt_alpha_cumprod_t[:, None, None]
|
| 275 |
+
sigmas = sigmas[:, None, None]
|
| 276 |
+
return alpha_cumprod_t, sigmas
|
| 277 |
+
|
| 278 |
+
def get_c_ins(self,sigmas): # use to adjust loss
|
| 279 |
+
ret = list()
|
| 280 |
+
for sigma in sigmas:
|
| 281 |
+
_, c_in = self.get_scalings(sigma=sigma)
|
| 282 |
+
ret.append(c_in)
|
| 283 |
+
return ret
|
| 284 |
+
|
| 285 |
+
# def predicted_origin(model_output, timesteps, sample, alphas, sigmas, prediction_type = "epsilon"):
|
| 286 |
+
# if prediction_type == "epsilon":
|
| 287 |
+
# sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
| 288 |
+
# alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
| 289 |
+
# pred_x_0 = (sample - sigmas * model_output) / alphas
|
| 290 |
+
# elif prediction_type == "v_prediction":
|
| 291 |
+
# sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
| 292 |
+
# alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
| 293 |
+
# pred_x_0 = alphas * sample - sigmas * model_output
|
| 294 |
+
# else:
|
| 295 |
+
# raise ValueError(f"Prediction type {prediction_type} currently not supported.")
|
| 296 |
+
# return pred_x_0
|
| 297 |
+
|
| 298 |
+
def get_customed_golden_noise(self
|
| 299 |
+
, input
|
| 300 |
+
, unconditional_guidance_scale:float
|
| 301 |
+
, sigma
|
| 302 |
+
, sigma_nxt
|
| 303 |
+
, need_cond = True
|
| 304 |
+
, noise_training_list = {}
|
| 305 |
+
, **kwargs):
|
| 306 |
+
"""User should ensure the input is a pure noise.
|
| 307 |
+
It's a customed golden noise, not the one purposed in the paper.
|
| 308 |
+
Maybe the one purposed in the paper should be implemented in the future."""
|
| 309 |
+
c_out, c_in = [append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
| 310 |
+
|
| 311 |
+
sigma_fn = lambda t: t.neg().exp()
|
| 312 |
+
t_fn = lambda sigma: sigma.log().neg()
|
| 313 |
+
if need_cond:
|
| 314 |
+
_, tmp_img = (input * c_in).chunk(2)
|
| 315 |
+
else :
|
| 316 |
+
tmp_img = input * c_in
|
| 317 |
+
# print(tmp_img.max())
|
| 318 |
+
# tmp_list.append(tmp_img)
|
| 319 |
+
eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs).sample
|
| 320 |
+
x_0 = input + eps * c_out
|
| 321 |
+
# normal_form_input = input * c_in
|
| 322 |
+
x_0_uncond, x_0 = x_0.chunk(2)
|
| 323 |
+
x_0 = x_0_uncond + unconditional_guidance_scale * (x_0 - x_0_uncond)
|
| 324 |
+
x_0 = torch.cat([x_0] * 2)
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
t, t_next = t_fn(sigma), t_fn(sigma_nxt)
|
| 328 |
+
h = t_next - t
|
| 329 |
+
|
| 330 |
+
x = (append_dims(sigma_fn(t_next) / sigma_fn(t),input.ndim)) * input - append_dims((-h).expm1(),input.ndim) * x_0
|
| 331 |
+
|
| 332 |
+
c_out_2, c_in_2 = [append_dims(x, input.ndim) for x in self.get_scalings(sigma_nxt)]
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
# e_t_uncond_ret, e_t_ret = self.get_eps(x * c_in_2, self.sigma_to_t(sigma_nxt), **kwargs).sample.chunk(2)
|
| 336 |
+
eps_ret = self.get_eps(x * c_in_2, self.sigma_to_t(sigma_nxt), **kwargs).sample
|
| 337 |
+
org_golden_noise = False
|
| 338 |
+
x_1 = x + eps_ret * c_out_2
|
| 339 |
+
if org_golden_noise:
|
| 340 |
+
ret = (x + append_dims((-h).expm1(),input.ndim) * x_1) / (append_dims(sigma_fn(t_next) / sigma_fn(t),input.ndim))
|
| 341 |
+
else :
|
| 342 |
+
e_t_uncond_ret, e_t_ret = eps_ret.chunk(2)
|
| 343 |
+
e_t_ret = e_t_uncond_ret + 1.0 * (e_t_ret - e_t_uncond_ret)
|
| 344 |
+
e_t_ret = torch.cat([e_t_ret] * 2)
|
| 345 |
+
ret = x_0 + e_t_ret * append_dims(sigma,input.ndim)
|
| 346 |
+
|
| 347 |
+
noise_training_list['org_noise'] = input * c_in
|
| 348 |
+
noise_training_list['golden_noise'] = ret * c_in
|
| 349 |
+
# noise_training_list.append(tmp_dict)
|
| 350 |
+
return ret
|
| 351 |
+
|
| 352 |
+
# timesteps = self.sigma_to_t(sigma)
|
| 353 |
+
|
| 354 |
+
# high_idx = torch.ceil(timesteps).int().to(input.device)
|
| 355 |
+
# low_idx = torch.floor(timesteps).int().to(input.device)
|
| 356 |
+
|
| 357 |
+
# nxt_ts = (timesteps - timesteps.new_ones(timesteps.shape[0])).to(input.device)
|
| 358 |
+
|
| 359 |
+
# w = (timesteps - low_idx) / (high_idx - low_idx)
|
| 360 |
+
|
| 361 |
+
# beta_1 = torch.tensor([1e-4],dtype=torch.float32).to(input.device)
|
| 362 |
+
# beta_T = torch.tensor([0.02],dtype=torch.float32).to(input.device)
|
| 363 |
+
# ddpm_max_step = torch.tensor([1000.0],dtype=torch.float32).to(input.device)
|
| 364 |
+
|
| 365 |
+
# beta_t: torch.Tensor = (beta_T - beta_1) / ddpm_max_step * timesteps + beta_1
|
| 366 |
+
# beta_t_prev: torch.Tensor = (beta_T - beta_1) / ddpm_max_step * nxt_ts + beta_1
|
| 367 |
+
|
| 368 |
+
# alpha_t = beta_t.new_ones(beta_t.shape[0]) - beta_t
|
| 369 |
+
# alpha_t = append_dims(alpha_t, e_t.ndim)
|
| 370 |
+
# alpha_t_prev = beta_t_prev.new_ones(beta_t_prev.shape[0]) - beta_t_prev
|
| 371 |
+
# alpha_t_prev = append_dims(alpha_t_prev, e_t.ndim)
|
| 372 |
+
# alpha_cumprod_t_floor = self.alphas_cumprod[low_idx]
|
| 373 |
+
# alpha_cumprod_t_floor = append_dims(alpha_cumprod_t_floor, e_t.ndim)
|
| 374 |
+
# alpha_cumprod_t:torch.Tensor = (alpha_cumprod_t_floor * alpha_t) #.unsqueeze(1)
|
| 375 |
+
# alpha_cumprod_t_prev:torch.Tensor = (alpha_cumprod_t_floor * alpha_t_prev) #.unsqueeze(1)
|
| 376 |
+
|
| 377 |
+
# sqrt_one_minus_alphas = (1 - alpha_cumprod_t).sqrt()
|
| 378 |
+
|
| 379 |
+
# dir_xt = (1. - alpha_cumprod_t_prev).sqrt() * e_t
|
| 380 |
+
# x_prev = alpha_cumprod_t_prev.sqrt() * x_0 + dir_xt
|
| 381 |
+
|
| 382 |
+
# e_t_uncond_ret, e_t_ret = self.get_eps(x_prev, nxt_ts, **kwargs).sample.chunk(2)
|
| 383 |
+
# e_t_ret = e_t_uncond_ret + 1.0 * (e_t_ret - e_t_uncond_ret)
|
| 384 |
+
# e_t_ret = torch.cat([e_t_ret] * 2)
|
| 385 |
+
# x_ret = alpha_t.sqrt() * x_0 + sqrt_one_minus_alphas * e_t_ret
|
| 386 |
+
|
| 387 |
+
# return x_ret
|
| 388 |
+
|
| 389 |
+
def forward(self, input, sigma, tmp_list=[], need_cond = True, **kwargs):
|
| 390 |
+
# c_out_1, c_in_1 = [append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
| 391 |
+
# if need_cond:
|
| 392 |
+
# tmp_img = input * c_in_1
|
| 393 |
+
# else :
|
| 394 |
+
# tmp_img = input * c_in_1
|
| 395 |
+
# tmp_list.append(tmp_img)
|
| 396 |
+
# timestep = self.sigma_to_t(sigma)
|
| 397 |
+
# eps = self.get_eps(sample = input * c_in_1, timestep = timestep, **kwargs).sample
|
| 398 |
+
# c_skip, c_out = self.scalings_for_boundary_conditions(timestep=self.sigma_to_t(sigma))
|
| 399 |
+
# # return (input + eps * c_out_1) * c_out + input * c_in_1 * c_skip
|
| 400 |
+
# return (input + eps * c_out_1)
|
| 401 |
+
c_out, c_in = [append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
| 402 |
+
if need_cond:
|
| 403 |
+
_, tmp_img = (input * c_in).chunk(2)
|
| 404 |
+
else :
|
| 405 |
+
tmp_img = input * c_in
|
| 406 |
+
# print(tmp_img.max())
|
| 407 |
+
eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs).sample
|
| 408 |
+
tmp_x0 = input + eps * c_out
|
| 409 |
+
tmp_dict = {'tmp_z': tmp_img, 'tmp_x0': tmp_x0}
|
| 410 |
+
tmp_list.append(tmp_dict)
|
| 411 |
+
return tmp_x0 #input + eps * c_out
|
| 412 |
+
|
| 413 |
+
def get_special_sigmas_with_timesteps(self,timesteps):
|
| 414 |
+
low_idx, high_idx, w = np.minimum(np.floor(timesteps),999), np.minimum(np.ceil(timesteps),999), torch.from_numpy( timesteps - np.floor(timesteps))
|
| 415 |
+
self.alphas_cumprod = self.alphas_cumprod.to('cpu')
|
| 416 |
+
alphas = (1 - w) * self.alphas_cumprod[low_idx] + w * self.alphas_cumprod[high_idx]
|
| 417 |
+
return ((1 - alphas) / alphas) ** 0.5
|
| 418 |
+
|
| 419 |
+
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
|
| 420 |
+
"""Calculates the noise level (sigma_down) to step down to and the amount
|
| 421 |
+
of noise to add (sigma_up) when doing an ancestral sampling step."""
|
| 422 |
+
if not eta:
|
| 423 |
+
return sigma_to, 0.
|
| 424 |
+
sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
|
| 425 |
+
sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
|
| 426 |
+
return sigma_down, sigma_up
|
| 427 |
+
|
| 428 |
+
def to_d(x, sigma, denoised):
|
| 429 |
+
"""Converts a denoiser output to a Karras ODE derivative."""
|
| 430 |
+
return (x - denoised) / append_dims(sigma, x.ndim)
|
| 431 |
+
|
| 432 |
+
class BatchedBrownianTree:
|
| 433 |
+
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
|
| 434 |
+
def __init__(self, x, t0, t1, seed=None, **kwargs):
|
| 435 |
+
t0, t1, self.sign = self.sort(t0, t1)
|
| 436 |
+
w0 = kwargs.get('w0', torch.zeros_like(x))
|
| 437 |
+
if seed is None:
|
| 438 |
+
seed = torch.randint(0, 2 ** 63 - 1, []).item()
|
| 439 |
+
self.batched = True
|
| 440 |
+
try:
|
| 441 |
+
assert len(seed) == x.shape[0]
|
| 442 |
+
w0 = w0[0]
|
| 443 |
+
except TypeError:
|
| 444 |
+
seed = [seed]
|
| 445 |
+
self.batched = False
|
| 446 |
+
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
|
| 447 |
+
|
| 448 |
+
@staticmethod
|
| 449 |
+
def sort(a, b):
|
| 450 |
+
return (a, b, 1) if a < b else (b, a, -1)
|
| 451 |
+
|
| 452 |
+
def __call__(self, t0, t1):
|
| 453 |
+
t0, t1, sign = self.sort(t0, t1)
|
| 454 |
+
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
|
| 455 |
+
return w if self.batched else w[0]

class BrownianTreeNoiseSampler:
    """A noise sampler backed by a torchsde.BrownianTree.

    Args:
        x (Tensor): The tensor whose shape, device and dtype to use to generate
            random samples.
        sigma_min (float): The low end of the valid interval.
        sigma_max (float): The high end of the valid interval.
        seed (int or List[int]): The random seed. If a list of seeds is
            supplied instead of a single integer, then the noise sampler will
            use one BrownianTree per batch item, each with its own seed.
        transform (callable): A function that maps sigma to the sampler's
            internal timestep.
    """

    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
        self.transform = transform
        t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, t0, t1, seed)

    def __call__(self, sigma, sigma_next):
        t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
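
# A minimal usage sketch (hypothetical, not called by the app): for a fixed seed
# the Brownian tree reproduces the same noise for the same sigma interval, which
# makes SDE sampling deterministic, and the division by sqrt(|t1 - t0|) above
# normalizes each increment to unit variance.
def _demo_brownian_noise(shape=(1, 4, 8, 8), seed=0):
    x = torch.zeros(shape)
    ns_a = BrownianTreeNoiseSampler(x, sigma_min=0.03, sigma_max=14.6, seed=seed)
    ns_b = BrownianTreeNoiseSampler(x, sigma_min=0.03, sigma_max=14.6, seed=seed)
    eps_a = ns_a(torch.tensor(2.0), torch.tensor(1.0))
    eps_b = ns_b(torch.tensor(2.0), torch.tensor(1.0))
    return torch.allclose(eps_a, eps_b)  # True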


@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
                 s_churn=0., s_tmin=0., s_tmax=float('inf'), tmp_list=[],
                 uncond_list=[], need_distill_uncond=False, start_free_step=1,
                 noise_training_list={}, s_noise=1.):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    intermediates = {'x_inter': [x], 'pred_x0': []}
    # Start with FreeU disabled; it is switched on once start_free_step is reached.
    register_free_upblock2d(model.inner_model.pipe, b1=1.0, b2=1.0, s1=1.0, s2=1.0)
    register_free_crossattn_upblock2d(model.inner_model.pipe, b1=1.0, b2=1.0, s1=1.0, s2=1.0)
    for i in trange(len(sigmas) - 1, disable=disable):
        if i == start_free_step:
            register_free_upblock2d(model.inner_model.pipe, b1=1.3, b2=1.4, s1=0.9, s2=0.2)
            register_free_crossattn_upblock2d(model.inner_model.pipe, b1=1.3, b2=1.4, s1=0.9, s2=0.2)
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        prompt_embeds, denoised = model(x, sigmas[i] * s_in, tmp_list=tmp_list, need_distill_uncond=need_distill_uncond, uncond_list=uncond_list, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        # Euler method
        x = x + d * dt
        intermediates['pred_x0'].append(denoised)
        intermediates['x_inter'].append(x)
    return prompt_embeds, intermediates, x


@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None,
                s_churn=0., s_tmin=0., s_tmax=float('inf'), tmp_list=[],
                uncond_list=[], need_distill_uncond=False,
                noise_training_list={}, s_noise=1.):
    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    intermediates = {'x_inter': [x], 'pred_x0': []}
    register_free_upblock2d(model.inner_model.pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
    register_free_crossattn_upblock2d(model.inner_model.pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
    for i in trange(len(sigmas) - 1, disable=disable):
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        prompt_embeds, denoised = model(x, sigmas[i] * s_in, tmp_list=tmp_list, need_distill_uncond=need_distill_uncond, uncond_list=uncond_list, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        if sigmas[i + 1] == 0:
            # Euler method (no second evaluation at sigma = 0, so record `denoised`
            # here; the previous version referenced denoised_2, which is unbound
            # on this branch).
            x = x + d * dt
            intermediates['pred_x0'].append(denoised)
        else:
            # Heun's method
            x_2 = x + d * dt
            _, denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt
            intermediates['pred_x0'].append(denoised_2)
        intermediates['x_inter'].append(x)
    return prompt_embeds, intermediates, x


@torch.no_grad()
def sample_dpmpp_ode(model, x, sigmas, need_golden_noise=False, start_free_step=1,
                     extra_args=None, callback=None, disable=None, tmp_list=[],
                     need_distill_uncond=False, need_raw_noise=False,
                     uncond_list=[], noise_training_list={}):
    """DPM-Solver++."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    if need_raw_noise:
        x = model.get_golden_noised(x=x, sigma=sigmas[0] * s_in, sigma_nxt=(sigmas[0] - 0.28) * s_in, noise_training_list=noise_training_list, **extra_args)
    register_free_upblock2d(model.inner_model.pipe, b1=1, b2=1, s1=1, s2=1)
    register_free_crossattn_upblock2d(model.inner_model.pipe, b1=1, b2=1, s1=1, s2=1)
    intermediates = {'x_inter': [x], 'pred_x0': []}

    for i in trange(len(sigmas) - 1, disable=disable):
        if i == start_free_step:
            register_free_upblock2d(model.inner_model.pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
            register_free_crossattn_upblock2d(model.inner_model.pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
        prompt_embeds, denoised = model(x, sigmas[i] * s_in, tmp_list=tmp_list, need_distill_uncond=need_distill_uncond, uncond_list=uncond_list, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t

        # First-order DPM-Solver++ update in log-sigma (exponential integrator) form.
        x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
        intermediates['pred_x0'].append(denoised)
        intermediates['x_inter'].append(x)
    return prompt_embeds, intermediates, x
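
# A minimal sketch (hypothetical helper, not used by the app): the first-order
# DPM-Solver++ update above is exact whenever the denoiser output is constant.
# For dx/dsigma = (x - x0) / sigma the exact solution is
# x(sigma_next) = x0 + (sigma_next / sigma) * (x - x0), which equals
# (sigma_next / sigma) * x - expm1(-h) * x0 with h = log(sigma / sigma_next).
def _demo_dpmpp_first_order_exact(x=5.0, x0=1.0, sigma=2.0, sigma_next=1.0):
    import math
    h = math.log(sigma / sigma_next)
    update = (sigma_next / sigma) * x - math.expm1(-h) * x0
    exact = x0 + (sigma_next / sigma) * (x - x0)
    assert abs(update - exact) < 1e-12
    return update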


@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, need_golden_noise=False, extra_args=None,
                     callback=None, tmp_list=[], need_distill_uncond=False,
                     uncond_list=[], disable=None, eta=1., s_noise=1.,
                     noise_sampler=None, r=1 / 2):
    """DPM-Solver++ (stochastic)."""
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    if need_golden_noise:
        x = model.get_golden_noised(x=x, sigma=sigmas[0] * s_in, sigma_nxt=sigmas[1] * s_in, **extra_args)

    intermediates = {'x_inter': [x], 'pred_x0': []}

    for i in trange(len(sigmas) - 1, disable=disable):
        prompt_embeds, denoised = model(x, sigmas[i] * s_in, tmp_list=tmp_list, need_distill_uncond=need_distill_uncond, uncond_list=uncond_list, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Euler method
            d = to_d(x, sigmas[i], denoised)
            dt = sigmas[i + 1] - sigmas[i]
            x = x + d * dt
            intermediates['pred_x0'].append(denoised)
            intermediates['x_inter'].append(x)
        else:
            # DPM-Solver++
            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
            h = t_next - t
            s = t + h * r
            fac = 1 / (2 * r)

            # Step 1
            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
            s_ = t_fn(sd)
            x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
            x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
            prompt_embeds, denoised_2 = model(x_2, sigma_fn(s) * s_in, tmp_list=tmp_list, need_distill_uncond=need_distill_uncond, uncond_list=uncond_list, **extra_args)

            # Step 2
            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
            t_next_ = t_fn(sd)
            denoised_d = (1 - fac) * denoised + fac * denoised_2
            x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
            intermediates['pred_x0'].append(x)
            x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
            intermediates['x_inter'].append(x)
    return prompt_embeds, intermediates, x


@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None,
                    tmp_list=[], need_distill_uncond=False, start_free_step=9,
                    uncond_list=[], stop_t=None):
    """DPM-Solver++(2M)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    old_denoised = None
    intermediates = {'x_inter': [x], 'pred_x0': []}
    register_free_upblock2d(model.inner_model.pipe, b1=1, b2=1, s1=1, s2=1)
    register_free_crossattn_upblock2d(model.inner_model.pipe, b1=1, b2=1, s1=1, s2=1)

    for i in trange(len(sigmas) - 1, disable=disable):
        if i == start_free_step and len(sigmas) > 6:
            register_free_upblock2d(model.inner_model.pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
            register_free_crossattn_upblock2d(model.inner_model.pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
        prompt_embeds, denoised = model(x, sigmas[i] * s_in, tmp_list=tmp_list, need_distill_uncond=need_distill_uncond, uncond_list=uncond_list, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        if old_denoised is None or sigmas[i + 1] == 0:
            # First step (or final step to sigma = 0): first-order update.
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
            intermediates['pred_x0'].append(denoised)
            intermediates['x_inter'].append(x)
        else:
            # Second-order multistep update using the previous denoised output.
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
            intermediates['x_inter'].append(x)
            intermediates['pred_x0'].append(denoised)
        old_denoised = denoised
        # The previous guard (`if i is not None`) was always true; stop_t is the
        # value that may be None here.
        if stop_t is not None and i == stop_t:
            return intermediates, x
    return prompt_embeds, intermediates, x
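
# A minimal sketch (hypothetical helper, not used by the app) of the 2M
# extrapolation above: with step ratio r = h_last / h, the weights on the
# current and previous denoised outputs are (1 + 1/(2r)) and -1/(2r), which
# always sum to 1, so the magnitude of the x0-prediction is preserved.
def _demo_dpmpp_2m_coeffs(r=1.0):
    c_cur, c_prev = 1 + 1 / (2 * r), -1 / (2 * r)
    assert abs(c_cur + c_prev - 1.0) < 1e-12
    return c_cur, c_prev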


# Adapted from pipelines.StableDiffusionPipeline.encode_prompt
def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):
    captions = []
    for caption in prompt_batch:
        if random.random() < proportion_empty_prompts:
            captions.append("")
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            captions.append(random.choice(caption) if is_train else caption[0])

    with torch.no_grad():
        text_inputs = tokenizer(
            captions,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device))[0]

    return prompt_embeds


def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())
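
# Example: list(chunk(range(5), 2)) == [(0, 1), (2, 3), (4,)]; the two-argument
# iter(callable, sentinel) form stops as soon as islice yields an empty tuple.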


torch_dtype = torch.float32
device = "cuda"
# Download the fp16-fix SDXL VAE weights from the Hub.
repo_id = "madebyollin/sdxl-vae-fp16-fix"
filename = "sdxl_vae.safetensors"
downloaded_path = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=".")

vae = AutoencoderKL.from_single_file(downloaded_path, torch_dtype=torch_dtype)
vae.to('cuda')

pipe = StableDiffusionXLPipeline.from_pretrained("Lykon/dreamshaper-xl-1-0", torch_dtype=torch_dtype, vae=vae)
# pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, vae=vae)

pipe.to('cuda')
# NPNet refines the initial latent noise ("golden noise"); its weights ship with the Space.
npn_net = NPNet128('SDXL', './sdxl.pth')

pipe = pipe.to(device)
register_free_upblock2d(pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
register_free_crossattn_upblock2d(pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
noise_scheduler = pipe.scheduler
alpha_schedule = noise_scheduler.alphas_cumprod.to(device=device, dtype=torch_dtype)
model_wrap = DiscreteEpsDDPMDenoiser(pipe.unet, alpha_schedule, quantize=False)
accelerator = accelerate.Accelerator()


def generate_image_with_steps(prompt, negative_prompt, seed, width, height, guidance_scale, num_inference_steps):
    """Helper function to generate an image with a specific number of steps."""
    prompts = [prompt]
    # FreeU is only enabled for the few-step schedules.
    if num_inference_steps <= 10:
        register_free_upblock2d(pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
        register_free_crossattn_upblock2d(pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
    else:
        register_free_upblock2d(pipe, b1=1, b2=1, s1=1, s2=1)
        register_free_crossattn_upblock2d(pipe, b1=1, b2=1, s1=1, s2=1)
    # Seed randomization is handled by the caller (infer), so `seed` is final
    # here; the previous re-randomization referenced an undefined name.

    def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
        prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
        return {"prompt_embeds": prompt_embeds}

    compute_embeddings_fn = functools.partial(
        compute_embeddings,
        proportion_empty_prompts=0,
        text_encoder=pipe.text_encoder,
        tokenizer=pipe.tokenizer,
    )
    generator = torch.Generator().manual_seed(seed)

    intermediate_photos = list()

    if isinstance(prompts, str):
        # Quality tags ('high quality, best quality, masterpiece, ...') could be appended here.
        prompts = (prompts,)
    if isinstance(prompts, (tuple, str)):
        prompts = list(prompts)

    shape = [4, height // 8, width // 8]
    start_free_step = num_inference_steps
    fir_stage_sigmas_ct = None
    sec_stage_sigmas_ct = None
    if num_inference_steps == 5:
        sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()
        sigmas = get_sigmas_karras(8, sigma_min, sigma_max, rho=5.0, device=device)
        ct_start, ct_end = model_wrap.sigma_to_t(sigmas[0]), model_wrap.sigma_to_t(sigmas[6])
        ct = get_sigmas_karras(6, ct_end.item(), ct_start.item(), rho=1.2, device='cpu', need_append_zero=False).numpy()
        sigmas_ct = model_wrap.get_special_sigmas_with_timesteps(ct).to(device=device)
    elif num_inference_steps == 6:
        sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()
        sigmas = get_sigmas_karras(8, sigma_min, sigma_max, rho=5.0, device=device)
        ct_start, ct_end = model_wrap.sigma_to_t(sigmas[0]), model_wrap.sigma_to_t(sigmas[6])
        ct = get_sigmas_karras(7, ct_end.item(), ct_start.item(), rho=1.2, device='cpu', need_append_zero=False).numpy()
        sigmas_ct = model_wrap.get_special_sigmas_with_timesteps(ct).to(device=device)
        start_free_step = 6
        # Split the schedule into an ODE stage and a short low-noise Euler stage.
        fir_stage_sigmas_ct = sigmas_ct[:-2]
        sec_stage_sigmas_ct = sigmas_ct[-3:]
    elif num_inference_steps == 8:
        sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()
        sigmas = get_sigmas_karras(12, sigma_min, sigma_max, rho=12.0, device=device)
        ct_start, ct_end = model_wrap.sigma_to_t(sigmas[0]), model_wrap.sigma_to_t(sigmas[10])
        ct = get_sigmas_karras(num_inference_steps + 1, ct_end.item(), ct_start.item(), rho=1.2, device='cpu', need_append_zero=False).numpy()
        sigmas_ct = model_wrap.get_special_sigmas_with_timesteps(ct).to(device=device)
        start_free_step = 8
    else:
        # Any other step count (e.g. the 50-step reference) falls back to the stock pipeline.
        image = pipe(prompt=prompts,
                     num_inference_steps=num_inference_steps,
                     guidance_scale=guidance_scale,
                     height=height,
                     width=width).images[0]
        return image

    ts = []
    for sigma in sigmas_ct:
        t = model_wrap.sigma_to_t(sigma)
        ts.append(t)

    c_in = model_wrap.get_c_ins(sigmas=sigmas_ct)
    x = torch.randn([1, *shape], device=device) * sigmas_ct[0]
    model_wrap_cfg = CFGDenoiser(model_wrap)
    (
        c,
        uc,
        _,
        _,
    ) = pipe.encode_prompt(
        prompt=prompts,
        device=device,
        do_classifier_free_guidance=True,
    )

    # Refine the initial noise with the pretrained NPNet ("golden noise") model.
    x = npn_net(x, c)

    extra_args = {'prompt': prompts, 'cond_scale': guidance_scale}
    with torch.no_grad():
        if not (num_inference_steps == 8 or num_inference_steps == 7):
            # Two-stage sampling: DPM-Solver++ ODE steps, then a few Euler steps
            # with small s_noise on the second-stage schedule.
            prompt_embeds, guide_distill, samples_ddim = sample_dpmpp_ode(
                model_wrap_cfg, x, fir_stage_sigmas_ct,
                extra_args=extra_args,
                disable=not accelerator.is_main_process,
                need_raw_noise=False,
                tmp_list=intermediate_photos)
            _, _, samples_ddim = sample_euler(
                model_wrap_cfg, samples_ddim, sec_stage_sigmas_ct,
                extra_args=extra_args,
                disable=not accelerator.is_main_process,
                s_noise=0.3,
                tmp_list=intermediate_photos)
        else:
            prompt_embeds, guide_distill, samples_ddim = sample_dpmpp_2m(
                model_wrap_cfg, x, sigmas_ct,
                extra_args=extra_args,
                start_free_step=start_free_step,
                disable=not accelerator.is_main_process,
                tmp_list=intermediate_photos)

    x_samples_ddim = pipe.vae.decode(samples_ddim / pipe.vae.config.scaling_factor).sample
    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

    for x_sample in x_samples_ddim:
        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
        image = Image.fromarray(x_sample.astype(np.uint8))

    return image
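
# A minimal sketch (hypothetical helper, not used above) of the Karras schedule
# that get_sigmas_karras is assumed to implement: sigmas interpolate linearly in
# sigma**(1/rho) space, so a larger rho concentrates more of the step budget near
# sigma_min. The default values below are illustrative, not the app's.
def _demo_karras_schedule(n=8, sigma_min=0.03, sigma_max=14.6, rho=5.0):
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho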


@spaces.GPU  # runs on ZeroGPU when hosted on Hugging Face Spaces
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    resolution,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Parse the resolution string into width and height.
    width, height = map(int, resolution.split('x'))

    # Few-step result with the selected sampler.
    image_quick = generate_image_with_steps(prompt, negative_prompt, seed, width, height, guidance_scale, num_inference_steps)

    # 50-step reference image for side-by-side comparison.
    image_50_steps = generate_image_with_steps(prompt, negative_prompt, seed, width, height, guidance_scale, 50)

    return image_quick, image_50_steps, seed

examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "a painting of a virus monster playing guitar",
    "a painting of a squirrel eating a burger",
]

css = """

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Hyperparameters are all you need")

        with gr.Row():
            prompt = gr.Text(

        run_button = gr.Button("Run", scale=0, variant="primary")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Our fast-inference result")
                result = gr.Image(label="Quick Result", show_label=False)
            with gr.Column():
                gr.Markdown("### Original 50-step result")
                result_50_steps = gr.Image(label="50 Steps Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            resolution = gr.Dropdown(
                choices=[
                    "1024x1024",
                    "1216x832",
                    "832x1216",
                ],
                value="1024x1024",
                label="Resolution",
            )

            with gr.Row():
                guidance_scale = gr.Slider(
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.5,  # default guidance scale; the solver supports flexible scales
                )

                num_inference_steps = gr.Dropdown(
                    choices=[6, 8],
                    value=8,
                    label="Number of inference steps",
                )

        gr.Examples(examples=examples, inputs=[prompt])

            negative_prompt,
            seed,
            randomize_seed,
            resolution,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, result_50_steps, seed],
    )

if __name__ == "__main__":

free_lunch_utils.py ADDED
@@ -0,0 +1,304 @@
import torch
import torch.fft as fft
from diffusers.utils import is_torch_version
from typing import Any, Dict, List, Optional, Tuple, Union


def isinstance_str(x: object, cls_name: str):
    """
    Checks whether x has any class *named* cls_name in its ancestry.
    Doesn't require access to the class's implementation.

    Useful for patching!
    """
    for _cls in x.__class__.__mro__:
        if _cls.__name__ == cls_name:
            return True
    return False


def Fourier_filter(x, threshold, scale):
    dtype = x.dtype
    x = x.type(torch.float32)
    # FFT
    x_freq = fft.fftn(x, dim=(-2, -1))
    x_freq = fft.fftshift(x_freq, dim=(-2, -1))

    B, C, H, W = x_freq.shape
    # Build the mask on the input's device to avoid device mismatches
    # (the previous version always targeted the default CUDA device).
    mask = torch.ones((B, C, H, W), device=x.device)

    crow, ccol = H // 2, W // 2
    # Scale the low-frequency band around the (shifted) spectrum centre.
    mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
    x_freq = x_freq * mask

    # IFFT
    x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
    x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real

    x_filtered = x_filtered.type(dtype)
    return x_filtered
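
# A minimal usage sketch (hypothetical): scale < 1 damps the low-frequency band
# of the skip features, while scale = 1 is an identity up to float32 FFT round-off.
def _demo_fourier_filter():
    x = torch.randn(1, 4, 16, 16)
    x_id = Fourier_filter(x, threshold=1, scale=1.0)
    x_damped = Fourier_filter(x, threshold=1, scale=0.2)
    return torch.allclose(x, x_id, atol=1e-4), (x - x_damped).abs().max()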


def register_upblock2d(model):
    def up_forward(self):
        def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
            for resnet in self.resnets:
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)

                        return custom_forward

                    if is_torch_version(">=", "1.11.0"):
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
                        )
                    else:
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb
                        )
                else:
                    hidden_states = resnet(hidden_states, temb)

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        return forward

    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "UpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)


def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
    def up_forward(self):
        def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
            for resnet in self.resnets:
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]

                # --------------- FreeU code -----------------------
                # Only operate on the first two stages
                if hidden_states.shape[1] == 1280:
                    hidden_states[:, :640] = hidden_states[:, :640] * self.b1
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
                if hidden_states.shape[1] == 640:
                    hidden_states[:, :320] = hidden_states[:, :320] * self.b2
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
                # ---------------------------------------------------------

                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)

                        return custom_forward

                    if is_torch_version(">=", "1.11.0"):
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
                        )
                    else:
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb
                        )
                else:
                    hidden_states = resnet(hidden_states, temb)

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        return forward

    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "UpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)
            setattr(upsample_block, 'b1', b1)
            setattr(upsample_block, 'b2', b2)
            setattr(upsample_block, 's1', s1)
            setattr(upsample_block, 's2', s2)
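
# Usage sketch (hypothetical pipeline object): calling the helper again with
# b1 = b2 = s1 = s2 = 1 restores stock UNet behaviour (up to float32 FFT
# round-off), which is how app.py toggles FreeU on and off around specific
# sampling steps:
#
#     register_free_upblock2d(pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)   # FreeU on
#     register_free_crossattn_upblock2d(pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
#     register_free_upblock2d(pipe, b1=1.0, b2=1.0, s1=1.0, s2=1.0)   # FreeU off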


def register_crossattn_upblock2d(model):
    def up_forward(self):
        def forward(
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            for resnet, attn in zip(self.resnets, self.attentions):
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),
                        hidden_states,
                        temb,
                        **ckpt_kwargs,
                    )
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(attn, return_dict=False),
                        hidden_states,
                        encoder_hidden_states,
                        None,  # timestep
                        None,  # class_labels
                        cross_attention_kwargs,
                        attention_mask,
                        encoder_attention_mask,
                        **ckpt_kwargs,
                    )[0]
                else:
                    hidden_states = resnet(hidden_states, temb)
                    hidden_states = attn(
                        hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        cross_attention_kwargs=cross_attention_kwargs,
                        attention_mask=attention_mask,
                        encoder_attention_mask=encoder_attention_mask,
                        return_dict=False,
                    )[0]

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        return forward

    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)


def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
    def up_forward(self):
        def forward(
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            for resnet, attn in zip(self.resnets, self.attentions):
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]

                # --------------- FreeU code -----------------------
                # Only operate on the first two stages
                if hidden_states.shape[1] == 1280:
                    hidden_states[:, :640] = hidden_states[:, :640] * self.b1
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
                if hidden_states.shape[1] == 640:
                    hidden_states[:, :320] = hidden_states[:, :320] * self.b2
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
                # ---------------------------------------------------------

                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),
                        hidden_states,
                        temb,
                        **ckpt_kwargs,
                    )
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(attn, return_dict=False),
                        hidden_states,
                        encoder_hidden_states,
                        None,  # timestep
                        None,  # class_labels
                        cross_attention_kwargs,
                        attention_mask,
                        encoder_attention_mask,
                        **ckpt_kwargs,
                    )[0]
                else:
                    hidden_states = resnet(hidden_states, temb)
                    hidden_states = attn(
                        hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        cross_attention_kwargs=cross_attention_kwargs,
                    )[0]

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        return forward

    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)
            setattr(upsample_block, 'b1', b1)
            setattr(upsample_block, 'b2', b2)
            setattr(upsample_block, 's1', s1)
            setattr(upsample_block, 's2', s2)

requirements.txt CHANGED
@@ -1,6 +1,10 @@
tqdm
einops
pytorch_lightning
accelerate>=0.20.0
torchsde
pycocotools
diffusers==0.32.2
timm
transformers==4.49
torch>=2.0.0

sdxl.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1a8629a4d939f9ed8f02ed2ad39b8317b701fb9e59d175ce186512e4a2687e48
size 121965599