import gradio as gr
import numpy as np
import random
import json
import spaces  # required when running on ZeroGPU Hugging Face Spaces
from diffusers import (
AutoencoderKL,
StableDiffusionXLPipeline,
)
from huggingface_hub import login, hf_hub_download
from PIL import Image
from SVDNoiseUnet import NPNet128
from itertools import islice  # needed by chunk() below
import functools
from free_lunch_utils import register_free_upblock2d, register_free_crossattn_upblock2d
import torch
import torch.nn as nn
from einops import rearrange
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
import accelerate
import torchsde
from tqdm import tqdm, trange
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "Lykon/dreamshaper-xl-1-0"  # Replace with the model you would like to use
precision_scope = autocast
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])
# New helper to load a list-of-dicts preference JSON
# JSON schema: [ { "human_preference": [int], "prompt": str, "file_path": [str] }, ... ]
def load_preference_json(json_path: str) -> list[dict]:
"""Load records from a JSON file formatted as a list of preference dicts."""
with open(json_path, 'r') as f:
data = json.load(f)
return data
# New helper to extract just the prompts from the preference JSON
# Returns a flat list of all 'prompt' values
def extract_prompts_from_pref_json(json_path: str) -> list[str]:
"""Load a JSON of preference records and return only the prompts."""
records = load_preference_json(json_path)
return [rec['prompt'] for rec in records]
# Example usage:
# prompts = extract_prompts_from_pref_json("path/to/preference.json")
# print(prompts)
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu', need_append_zero=True):
"""Constructs the noise schedule of Karras et al. (2022)."""
ramp = torch.linspace(0, 1, n)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return append_zero(sigmas).to(device) if need_append_zero else sigmas.to(device)
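# Illustrative example (values approximate): get_sigmas_karras(3, 0.1, 10.0)
# returns roughly tensor([10.0000, 1.4508, 0.1000, 0.0000]): a descending
# schedule from sigma_max down to sigma_min, with a trailing zero appended.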
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
return x[(...,) + (None,) * dims_to_append]
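# Example: append_dims(torch.ones(2), 4) has shape (2, 1, 1, 1), letting a
# per-sample sigma broadcast against a (B, C, H, W) latent.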
class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def prepare_sdxl_pipeline_step_parameter(self, pipe, prompts, need_cfg, device):
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = pipe.encode_prompt(
prompt=prompts,
device=device,
do_classifier_free_guidance=need_cfg,
)
# timesteps = pipe.scheduler.timesteps
prompt_embeds = prompt_embeds.to(device)
add_text_embeds = pooled_prompt_embeds.to(device)
original_size = (1024, 1024)
crops_coords_top_left = (0, 0)
target_size = (1024, 1024)
text_encoder_projection_dim = None
add_time_ids = list(original_size + crops_coords_top_left + target_size)
if pipe.text_encoder_2 is None:
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
else:
text_encoder_projection_dim = pipe.text_encoder_2.config.projection_dim
passed_add_embed_dim = (
pipe.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
)
expected_add_embed_dim = pipe.unet.add_embedding.linear_1.in_features
if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
)
add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
add_time_ids = add_time_ids.to(device)
negative_add_time_ids = add_time_ids
if need_cfg:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
ret_dict = {
"text_embeds": add_text_embeds,
"time_ids": add_time_ids
}
return prompt_embeds, ret_dict
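    # SDXL's added conditioning: pooled text embeddings plus "time ids" built
    # from original_size + crops_coords_top_left + target_size; when CFG is on,
    # negative and positive halves are concatenated along the batch dimension.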
    def get_golden_noised(self, x, sigma, sigma_nxt, prompt, cond_scale, need_distill_uncond=False, tmp_list=None, uncond_list=None, noise_training_list=None):
        noise_training_list = {} if noise_training_list is None else noise_training_list
        x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
sigma_nxt = torch.cat([sigma_nxt] * 2)
prompt_embeds, cond_kwargs = self.prepare_sdxl_pipeline_step_parameter(self.inner_model.pipe, prompt, need_cfg=True, device=self.inner_model.pipe.device)
_, ret = self.inner_model.get_customed_golden_noise(x_in
, cond_scale
, sigma_in, sigma_nxt
, True
, noise_training_list=noise_training_list
, encoder_hidden_states=prompt_embeds.to(device=x.device, dtype=x.dtype)
, added_cond_kwargs=cond_kwargs).chunk(2)
return ret
    def forward(self, x, sigma, prompt, cond_scale, need_distill_uncond=False, tmp_list=None, uncond_list=None):
        tmp_list = [] if tmp_list is None else tmp_list
        uncond_list = [] if uncond_list is None else uncond_list
        prompt_embeds, cond_kwargs = self.prepare_sdxl_pipeline_step_parameter(self.inner_model.pipe, prompt, need_cfg=True, device=self.inner_model.pipe.device)
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
# cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in
, sigma_in
, tmp_list
, encoder_hidden_states=prompt_embeds.to(device=x.device, dtype=x.dtype)
, added_cond_kwargs=cond_kwargs).chunk(2)
if need_distill_uncond:
uncond_list.append(uncond)
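        # Standard classifier-free guidance on the denoised predictions:
        # cond_scale = 1 recovers the conditional output, and larger values
        # push the sample further toward the prompt.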
return prompt_embeds, uncond + (cond - uncond) * cond_scale
class DiscreteSchedule(nn.Module):
"""A mapping between continuous noise levels (sigmas) and a list of discrete noise
levels."""
def __init__(self, sigmas, quantize):
super().__init__()
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())
self.quantize = quantize
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def get_sigmas(self, n=None):
if n is None:
return append_zero(self.sigmas.flip(0))
t_max = len(self.sigmas) - 1
t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
return append_zero(self.t_to_sigma(t))
def sigma_to_t(self, sigma, quantize=None):
quantize = self.quantize if quantize is None else quantize
log_sigma = sigma.log()
dists = log_sigma - self.log_sigmas[:, None]
if quantize:
return dists.abs().argmin(dim=0).view(sigma.shape)
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
w = (low - log_sigma) / (low - high)
w = w.clamp(0, 1)
t = (1 - w) * low_idx + w * high_idx
return t.view(sigma.shape)
def t_to_sigma(self, t):
t = t.float()
low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp()
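    # Note: with quantize=False, t_to_sigma(sigma_to_t(sigma)) reproduces sigma
    # (up to clamping at the schedule ends), since both directions interpolate
    # linearly in log-sigma between the discrete levels.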
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
noise)."""
def __init__(self, pipe, alphas_cumprod, quantize = False):
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
self.pipe = pipe
self.inner_model = pipe.unet
# self.alphas_cumprod = alphas_cumprod.flip(0)
# Prepare a reversed version of alphas_cumprod for backward scheduling
self.register_buffer('alphas_cumprod', alphas_cumprod)
# self.register_buffer('alphas_cumprod_prev', append_zero(alphas_cumprod[:-1]))
self.sigma_data = 1.
def get_scalings(self, sigma):
c_out = -sigma
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
return c_out, c_in
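    # With sigma_data = 1, c_in = 1 / sqrt(sigma^2 + 1) rescales the noisy
    # latent to the unit-variance input the eps-model expects, and c_out = -sigma
    # turns the predicted noise into a denoised estimate: x0 = x - sigma * eps.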
def get_eps(self, *args, **kwargs):
return self.inner_model(*args, **kwargs)
def get_alphact_and_sigma(self, timesteps, x_0, noise):
high_idx = torch.ceil(timesteps).int()
low_idx = torch.floor(timesteps).int()
nxt_ts = timesteps - timesteps.new_ones(timesteps.shape[0])
w = (timesteps - low_idx) / (high_idx - low_idx)
beta_1 = torch.tensor([1e-4],dtype=torch.float32)
beta_T = torch.tensor([0.02],dtype=torch.float32)
ddpm_max_step = torch.tensor([1000.0],dtype=torch.float32)
beta_t: torch.Tensor = (beta_T - beta_1) / ddpm_max_step * timesteps + beta_1
beta_t_prev: torch.Tensor = (beta_T - beta_1) / ddpm_max_step * nxt_ts + beta_1
alpha_t = beta_t.new_ones(beta_t.shape[0]) - beta_t
alpha_t_prev = beta_t.new_ones(beta_t.shape[0]) - beta_t_prev
        alpha_cumprod_t_floor = self.alphas_cumprod[low_idx]
alpha_cumprod_t = (alpha_cumprod_t_floor * alpha_t) #.unsqueeze(1)
sqrt_alpha_cumprod_t = torch.sqrt(alpha_cumprod_t)
sigmas = torch.sqrt(alpha_cumprod_t.new_ones(alpha_cumprod_t.shape[0]) - alpha_cumprod_t)
# Fix broadcasting
sqrt_alpha_cumprod_t = sqrt_alpha_cumprod_t[:, None, None]
sigmas = sigmas[:, None, None]
return alpha_cumprod_t, sigmas
    def get_c_ins(self, sigmas):  # used to adjust the loss
ret = list()
for sigma in sigmas:
_, c_in = self.get_scalings(sigma=sigma)
ret.append(c_in)
return ret
def get_customed_golden_noise(self
, input
, unconditional_guidance_scale:float
, sigma
, sigma_nxt
, need_cond = True
, noise_training_list = {}
, **kwargs):
"""User should ensure the input is a pure noise.
It's a customed golden noise, not the one purposed in the paper.
Maybe the one purposed in the paper should be implemented in the future."""
c_out, c_in = [append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
if need_cond:
_, tmp_img = (input * c_in).chunk(2)
else :
tmp_img = input * c_in
# print(tmp_img.max())
# tmp_list.append(tmp_img)
eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs).sample
x_0 = input + eps * c_out
# normal_form_input = input * c_in
x_0_uncond, x_0 = x_0.chunk(2)
x_0 = x_0_uncond + unconditional_guidance_scale * (x_0 - x_0_uncond)
x_0 = torch.cat([x_0] * 2)
t, t_next = t_fn(sigma), t_fn(sigma_nxt)
h = t_next - t
x = (append_dims(sigma_fn(t_next) / sigma_fn(t),input.ndim)) * input - append_dims((-h).expm1(),input.ndim) * x_0
c_out_2, c_in_2 = [append_dims(x, input.ndim) for x in self.get_scalings(sigma_nxt)]
# e_t_uncond_ret, e_t_ret = self.get_eps(x * c_in_2, self.sigma_to_t(sigma_nxt), **kwargs).sample.chunk(2)
eps_ret = self.get_eps(x * c_in_2, self.sigma_to_t(sigma_nxt), **kwargs).sample
org_golden_noise = False
x_1 = x + eps_ret * c_out_2
if org_golden_noise:
ret = (x + append_dims((-h).expm1(),input.ndim) * x_1) / (append_dims(sigma_fn(t_next) / sigma_fn(t),input.ndim))
else :
e_t_uncond_ret, e_t_ret = eps_ret.chunk(2)
e_t_ret = e_t_uncond_ret + 1.0 * (e_t_ret - e_t_uncond_ret)
e_t_ret = torch.cat([e_t_ret] * 2)
ret = x_0 + e_t_ret * append_dims(sigma,input.ndim)
noise_training_list['org_noise'] = input * c_in
noise_training_list['golden_noise'] = ret * c_in
# noise_training_list.append(tmp_dict)
return ret
    def forward(self, input, sigma, tmp_list=None, need_cond=True, **kwargs):
        tmp_list = [] if tmp_list is None else tmp_list
c_out, c_in = [append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
if need_cond:
_, tmp_img = (input * c_in).chunk(2)
else :
tmp_img = input * c_in
# print(tmp_img.max())
eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs).sample
tmp_x0 = input + eps * c_out
tmp_dict = {'tmp_z': tmp_img, 'tmp_x0': tmp_x0}
tmp_list.append(tmp_dict)
return tmp_x0 #input + eps * c_out
    def get_special_sigmas_with_timesteps(self, timesteps):
        # Indices must be integers to index into alphas_cumprod; interpolate
        # between the neighbouring integer steps for fractional timesteps.
        low_idx = np.minimum(np.floor(timesteps), 999).astype(np.int64)
        high_idx = np.minimum(np.ceil(timesteps), 999).astype(np.int64)
        w = torch.from_numpy(timesteps - np.floor(timesteps))
        self.alphas_cumprod = self.alphas_cumprod.to('cpu')
        alphas = (1 - w) * self.alphas_cumprod[low_idx] + w * self.alphas_cumprod[high_idx]
        return ((1 - alphas) / alphas) ** 0.5
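    # Illustrative: fractional timesteps such as np.array([999.0, 500.5, 1.0])
    # are mapped to sigmas via sigma_t = sqrt((1 - alphabar_t) / alphabar_t)
    # with alphabar interpolated between the neighbouring integer steps.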
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
"""Calculates the noise level (sigma_down) to step down to and the amount
of noise to add (sigma_up) when doing an ancestral sampling step."""
if not eta:
return sigma_to, 0.
sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
return sigma_down, sigma_up
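# Example (eta=1): get_ancestral_step(2.0, 1.0) returns approximately
# (0.5, 0.866): step deterministically down to sigma_down = 0.5, then add
# fresh noise at sigma_up so the total noise level matches sigma_to = 1.0.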
def to_d(x, sigma, denoised):
"""Converts a denoiser output to a Karras ODE derivative."""
return (x - denoised) / append_dims(sigma, x.ndim)
class BatchedBrownianTree:
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
def __init__(self, x, t0, t1, seed=None, **kwargs):
t0, t1, self.sign = self.sort(t0, t1)
w0 = kwargs.get('w0', torch.zeros_like(x))
if seed is None:
seed = torch.randint(0, 2 ** 63 - 1, []).item()
self.batched = True
try:
assert len(seed) == x.shape[0]
w0 = w0[0]
except TypeError:
seed = [seed]
self.batched = False
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
@staticmethod
def sort(a, b):
return (a, b, 1) if a < b else (b, a, -1)
def __call__(self, t0, t1):
t0, t1, sign = self.sort(t0, t1)
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
return w if self.batched else w[0]
class BrownianTreeNoiseSampler:
"""A noise sampler backed by a torchsde.BrownianTree.
Args:
x (Tensor): The tensor whose shape, device and dtype to use to generate
random samples.
sigma_min (float): The low end of the valid interval.
sigma_max (float): The high end of the valid interval.
seed (int or List[int]): The random seed. If a list of seeds is
supplied instead of a single integer, then the noise sampler will
use one BrownianTree per batch item, each with its own seed.
transform (callable): A function that maps sigma to the sampler's
internal timestep.
"""
def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
self.transform = transform
t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
self.tree = BatchedBrownianTree(x, t0, t1, seed)
def __call__(self, sigma, sigma_next):
t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
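# Usage sketch: ns = BrownianTreeNoiseSampler(x, sigma_min, sigma_max);
# ns(sigma, sigma_next) yields unit-variance noise (the Brownian increment
# divided by sqrt(|t1 - t0|)), and the same (sigma, sigma_next) pair always
# reproduces the same noise for a given seed.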
@torch.no_grad()
def sample_euler(model
, x
, sigmas
, extra_args=None
, callback=None
, disable=None
, s_churn=0.
, s_tmin=0.
, s_tmax=float('inf')
    , tmp_list=None
    , uncond_list=None
    , need_distill_uncond=False
    , start_free_step=1
    , noise_training_list=None
    , s_noise=1.):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
    extra_args = {} if extra_args is None else extra_args
    tmp_list = [] if tmp_list is None else tmp_list
    uncond_list = [] if uncond_list is None else uncond_list
s_in = x.new_ones([x.shape[0]])
intermediates = {'x_inter': [x],'pred_x0': []}
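    # FreeU: b1/b2 scale the UNet backbone features and s1/s2 scale the skip
    # connections; registering with all-ones keeps it disabled until
    # start_free_step is reached.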
register_free_upblock2d(model.inner_model.pipe, b1=1.0, b2=1.0, s1=1.0, s2=1.0)
register_free_crossattn_upblock2d(model.inner_model.pipe, b1=1.0, b2=1.0, s1=1.0, s2=1.0)
for i in trange(len(sigmas) - 1, disable=disable):
if i == start_free_step:
register_free_upblock2d(model.inner_model.pipe, b1=1.3, b2=1.4, s1=0.9, s2=0.2)
register_free_crossattn_upblock2d(model.inner_model.pipe, b1=1.3, b2=1.4, s1=0.9, s2=0.2)
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
prompt_embeds, denoised = model(x, sigmas[i] * s_in, tmp_list=tmp_list,need_distill_uncond=need_distill_uncond,uncond_list=uncond_list, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = x + d * dt
intermediates['pred_x0'].append(denoised)
intermediates['x_inter'].append(x)
return prompt_embeds, intermediates, x
@torch.no_grad()
def sample_heun(model
, x
, sigmas
, extra_args=None
, callback=None
, disable=None
, s_churn=0.
, s_tmin=0.
, s_tmax=float('inf')
    , tmp_list=None
    , uncond_list=None
    , need_distill_uncond=False
    , noise_training_list=None
    , s_noise=1.):
    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
    extra_args = {} if extra_args is None else extra_args
    tmp_list = [] if tmp_list is None else tmp_list
    uncond_list = [] if uncond_list is None else uncond_list
s_in = x.new_ones([x.shape[0]])
intermediates = {'x_inter': [x],'pred_x0': []}
register_free_upblock2d(model.inner_model.pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
register_free_crossattn_upblock2d(model.inner_model.pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
prompt_embeds, denoised = model(x, sigmas[i] * s_in, tmp_list=tmp_list,need_distill_uncond=need_distill_uncond,uncond_list=uncond_list, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
dt = sigmas[i + 1] - sigma_hat
if sigmas[i + 1] == 0:
# Euler method
x = x + d * dt
else:
# Heun's method
x_2 = x + d * dt
_, denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
d_prime = (d + d_2) / 2
x = x + d_prime * dt
intermediates['pred_x0'].append(denoised_2)
intermediates['x_inter'].append(x)
return prompt_embeds, intermediates, x
@torch.no_grad()
def sample_dpmpp_ode(model
, x
, sigmas
, need_golden_noise = False
, start_free_step = 1
, extra_args=None, callback=None
    , disable=None, tmp_list=None
    , need_distill_uncond=False
    , need_raw_noise=False
    , uncond_list=None
    , noise_training_list=None):
    """DPM-Solver++."""
    extra_args = {} if extra_args is None else extra_args
    tmp_list = [] if tmp_list is None else tmp_list
    uncond_list = [] if uncond_list is None else uncond_list
    noise_training_list = {} if noise_training_list is None else noise_training_list
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
old_denoised = None
if need_raw_noise:
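        # Note: despite the name, need_raw_noise=True refines the raw Gaussian
        # noise into a "golden" starting noise before the solver loop begins.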
x = model.get_golden_noised(x=x,sigma=sigmas[0] * s_in, sigma_nxt=(sigmas[0] - 0.28) * s_in, noise_training_list=noise_training_list,**extra_args)
register_free_upblock2d(model.inner_model.pipe, b1=1, b2=1, s1=1, s2=1)
register_free_crossattn_upblock2d(model.inner_model.pipe, b1=1, b2=1, s1=1, s2=1)
intermediates = {'x_inter': [x],'pred_x0': []}
for i in trange(len(sigmas) - 1, disable=disable):
if i == start_free_step:
register_free_upblock2d(model.inner_model.pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
register_free_crossattn_upblock2d(model.inner_model.pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
# macs, params = profile(model, inputs=(x, sigmas[i] * s_in,*extra_args.values(),need_distill_uncond,tmp_list,uncond_list, ))
prompt_embeds, denoised = model(x, sigmas[i] * s_in, tmp_list=tmp_list,need_distill_uncond=need_distill_uncond,uncond_list=uncond_list, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
intermediates['pred_x0'].append(denoised)
intermediates['x_inter'].append(x)
# print(denoised_d.max())
# intermediates['noise'].append(denoised_d)
return prompt_embeds, intermediates,x
@torch.no_grad()
def sample_dpmpp_sde(model
, x
, sigmas
, need_golden_noise = False
, extra_args=None
, callback=None
    , tmp_list=None
    , need_distill_uncond=False
    , uncond_list=None
    , disable=None, eta=1.
    , s_noise=1.
    , noise_sampler=None
    , r=1 / 2):
    """DPM-Solver++ (stochastic)."""
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    tmp_list = [] if tmp_list is None else tmp_list
    uncond_list = [] if uncond_list is None else uncond_list
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
if need_golden_noise:
x = model.get_golden_noised(x=x,sigma=sigmas[0] * s_in, sigma_nxt=sigmas[1] * s_in,**extra_args)
intermediates = {'x_inter': [x],'pred_x0': []}
for i in trange(len(sigmas) - 1, disable=disable):
prompt_embeds, denoised = model(x, sigmas[i] * s_in, tmp_list=tmp_list,need_distill_uncond=need_distill_uncond,uncond_list=uncond_list, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Euler method
d = to_d(x, sigmas[i], denoised)
dt = sigmas[i + 1] - sigmas[i]
x = x + d * dt
intermediates['pred_x0'].append(denoised)
intermediates['x_inter'].append(x)
else:
# DPM-Solver++
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
s = t + h * r
fac = 1 / (2 * r)
# Step 1
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
s_ = t_fn(sd)
x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
            prompt_embeds, denoised_2 = model(x_2, sigma_fn(s) * s_in, tmp_list=tmp_list, need_distill_uncond=need_distill_uncond, uncond_list=uncond_list, **extra_args)
# Step 2
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
t_next_ = t_fn(sd)
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
intermediates['pred_x0'].append(x)
x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
intermediates['x_inter'].append(x)
return prompt_embeds, intermediates,x
@torch.no_grad()
def sample_dpmpp_2m(model
, x
, sigmas
# , need_golden_noise = True
, extra_args=None
, callback=None
, disable=None
    , tmp_list=None
    , need_distill_uncond=False
    , start_free_step=9
    , uncond_list=None
    , stop_t=None):
    """DPM-Solver++(2M)."""
    extra_args = {} if extra_args is None else extra_args
    tmp_list = [] if tmp_list is None else tmp_list
    uncond_list = [] if uncond_list is None else uncond_list
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
old_denoised = None
# if need_golden_noise:
# x = model.get_golden_noised(x=x,sigma=sigmas[0] * s_in, sigma_nxt=sigmas[1] * s_in,**extra_args)
intermediates = {'x_inter': [x],'pred_x0': []}
register_free_upblock2d(model.inner_model.pipe, b1=1, b2=1, s1=1, s2=1)
register_free_crossattn_upblock2d(model.inner_model.pipe, b1=1, b2=1, s1=1, s2=1)
for i in trange(len(sigmas) - 1, disable=disable):
if i == start_free_step and len(sigmas) > 6:
register_free_upblock2d(model.inner_model.pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
register_free_crossattn_upblock2d(model.inner_model.pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
# else:
# register_free_upblock2d(model.inner_model.pipe, b1=1.1, b2=1.1, s1=1.0, s2=1.0)
# register_free_crossattn_upblock2d(model.inner_model.pipe, b1=1.1, b2=1.1, s1=1.0, s2=1.0)
# macs, params = profile(model, inputs=(x, sigmas[i] * s_in,*extra_args.values(),need_distill_uncond,tmp_list,uncond_list, ))
prompt_embeds, denoised = model(x, sigmas[i] * s_in, tmp_list=tmp_list,need_distill_uncond=need_distill_uncond,uncond_list=uncond_list, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
if old_denoised is None or sigmas[i + 1] == 0:
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
intermediates['pred_x0'].append(denoised)
intermediates['x_inter'].append(x)
else:
h_last = t - t_fn(sigmas[i - 1])
r = h_last / h
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
intermediates['x_inter'].append(x)
intermediates['pred_x0'].append(denoised)
# print(denoised_d.max())
old_denoised = denoised
        if stop_t is not None and i == stop_t:
            return prompt_embeds, intermediates, x
# intermediates['noise'].append(denoised_d)
return prompt_embeds, intermediates,x
# Adapted from pipelines.StableDiffusionPipeline.encode_prompt
def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):
captions = []
for caption in prompt_batch:
if random.random() < proportion_empty_prompts:
captions.append("")
elif isinstance(caption, str):
captions.append(caption)
elif isinstance(caption, (list, np.ndarray)):
# take a random caption if there are multiple
captions.append(random.choice(caption) if is_train else caption[0])
with torch.no_grad():
text_inputs = tokenizer(
captions,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device))[0]
return prompt_embeds
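# Usage sketch (output shape depends on the text encoder):
#   embeds = encode_prompt(["a photo of a cat"], pipe.text_encoder,
#                          pipe.tokenizer, proportion_empty_prompts=0.0)
#   # -> last hidden states of shape (1, tokenizer.model_max_length, hidden_size)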
def chunk(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
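# Example: list(chunk(range(5), 2)) == [(0, 1), (2, 3), (4,)]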
torch_dtype = torch.float32
# device = "cuda"
# pipe = StableDiffusionPipeline.from_single_file( "./counterfeit/Counterfeit-V3.0_fp32.safetensors")
repo_id = "madebyollin/sdxl-vae-fp16-fix" # e.g., "distilbert/distilgpt2"
filename = "sdxl_vae.safetensors" # e.g., "pytorch_model.bin"
downloaded_path = hf_hub_download(repo_id=repo_id, filename=filename,cache_dir=".")
# pipe = StableDiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4')
vae = AutoencoderKL.from_single_file(downloaded_path, torch_dtype=torch_dtype)
pipe = StableDiffusionXLPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype, vae=vae)
# pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, vae=vae)
pipe = pipe.to(device, dtype=torch_dtype)  # also moves the custom VAE with the pipeline
npn_net = NPNet128('SDXL', './sdxl.pth')
register_free_upblock2d(pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
register_free_crossattn_upblock2d(pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
noise_scheduler = pipe.scheduler
alpha_schedule = noise_scheduler.alphas_cumprod.to(device=device, dtype=torch_dtype)
model_wrap = DiscreteEpsDDPMDenoiser(pipe, alpha_schedule, quantize=False)
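# model_wrap exposes the SDXL UNet through the Karras sigma parameterization,
# sigma_t = sqrt((1 - alphabar_t) / alphabar_t), so the k-diffusion samplers
# defined above can drive it directly.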
accelerator = accelerate.Accelerator()
def generate_image_with_steps(prompt, negative_prompt, seed, width, height, guidance_scale, num_inference_steps):
"""Helper function to generate image with specific number of steps"""
prompts = [prompt]
if num_inference_steps <= 10:
register_free_upblock2d(pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
register_free_crossattn_upblock2d(pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
else:
register_free_upblock2d(pipe, b1=1, b2=1, s1=1, s2=1)
register_free_crossattn_upblock2d(pipe, b1=1, b2=1, s1=1, s2=1)
    # Seed handling (including randomization) is done by the caller, infer().
def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
return {"prompt_embeds": prompt_embeds}
compute_embeddings_fn = functools.partial(
compute_embeddings,
proportion_empty_prompts=0,
text_encoder=pipe.text_encoder,
tokenizer=pipe.tokenizer,
)
    generator = torch.Generator().manual_seed(seed)  # CPU generator; noise is moved to `device` below
intermediate_photos = list()
    # `prompts` is already a list; normalize here in case a bare string is ever passed.
    if isinstance(prompts, str):
        prompts = [prompts]
shape = [4, height // 8, width // 8]
start_free_step = num_inference_steps
fir_stage_sigmas_ct = None
sec_stage_sigmas_ct = None
# sigmas = model_wrap.get_sigmas(opt.ddim_steps).to(device=device)
if num_inference_steps == 5:
sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()
        sigmas = get_sigmas_karras(8, sigma_min, sigma_max, rho=5.0, device=device)
ct_start, ct_end = model_wrap.sigma_to_t(sigmas[0]), model_wrap.sigma_to_t(sigmas[6])
# sigma_kct_start, sigma_kct_end = sigmas[0].item(), sigmas[5].item()
ct = get_sigmas_karras(6, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
sigmas_ct = model_wrap.get_special_sigmas_with_timesteps(ct).to(device=device)
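        # Note: this 5-step branch is not reachable from the UI (the dropdown
        # offers 6 and 8), and the two-stage fir/sec split used below is only
        # set up in the 6-step branch.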
elif num_inference_steps == 6:
sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()
        sigmas = get_sigmas_karras(8, sigma_min, sigma_max, rho=5.0, device=device)
ct_start, ct_end = model_wrap.sigma_to_t(sigmas[0]), model_wrap.sigma_to_t(sigmas[6])
# sigma_kct_start, sigma_kct_end = sigmas[0].item(), sigmas[5].item()
ct = get_sigmas_karras(7, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
sigmas_ct = model_wrap.get_special_sigmas_with_timesteps(ct).to(device=device)
start_free_step = 6
fir_stage_sigmas_ct = sigmas_ct[:-2]
sec_stage_sigmas_ct = sigmas_ct[-3:]
elif num_inference_steps == 8:
sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()
        sigmas = get_sigmas_karras(12, sigma_min, sigma_max, rho=12.0, device=device)
ct_start, ct_end = model_wrap.sigma_to_t(sigmas[0]), model_wrap.sigma_to_t(sigmas[10])
ct = get_sigmas_karras(num_inference_steps + 1, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
sigmas_ct = model_wrap.get_special_sigmas_with_timesteps(ct).to(device=device)
start_free_step = 8
else:
        image = pipe(prompt=prompts
            , num_inference_steps=num_inference_steps
            , guidance_scale=guidance_scale
            , height=height
            , width=width
            , generator=generator).images[0]
        return image
ts = []
for sigma in sigmas_ct:
t = model_wrap.sigma_to_t(sigma)
ts.append(t)
c_in = model_wrap.get_c_ins(sigmas=sigmas_ct)
    x = torch.randn([1, *shape], generator=generator).to(device) * sigmas_ct[0]
model_wrap_cfg = CFGDenoiser(model_wrap)
(
c,
uc,
_,
_,
) = pipe.encode_prompt(
prompt=prompts,
device=device,
do_classifier_free_guidance=True,
)
# prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# if (num_inference_steps != -1 or num_inference_steps <= 8) and not opt.force_not_use_NPNet:
# x = npn_net(x,c)
extra_args = {'prompt': prompts, 'cond_scale': guidance_scale}
with torch.no_grad():
# with precision_scope("cuda" if torch.cuda.is_available() else "cpu"):
if not (num_inference_steps == 8 or num_inference_steps == 7):
prompt_embeds, guide_distill, samples_ddim = sample_dpmpp_ode(model_wrap_cfg
, x
, fir_stage_sigmas_ct
, extra_args=extra_args
, disable=not accelerator.is_main_process
, need_raw_noise = False
, tmp_list=intermediate_photos)
_, _, samples_ddim = sample_euler(model_wrap_cfg
, samples_ddim
, sec_stage_sigmas_ct
, extra_args=extra_args
, disable=not accelerator.is_main_process
, s_noise = 0.3
, tmp_list=intermediate_photos)
else:
prompt_embeds, guide_distill, samples_ddim = sample_dpmpp_2m(model_wrap_cfg
, x
, sigmas_ct
, extra_args=extra_args
, start_free_step=start_free_step
, disable=not accelerator.is_main_process
, tmp_list=intermediate_photos)
# print('2m')
x_samples_ddim = pipe.vae.decode(samples_ddim / pipe.vae.config.scaling_factor).sample
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
    for x_sample in x_samples_ddim:
        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
        image = Image.fromarray(x_sample.astype(np.uint8))
return image
@spaces.GPU  # request a GPU when running on ZeroGPU Spaces
def infer(
prompt,
negative_prompt,
seed,
randomize_seed,
resolution,
guidance_scale,
num_inference_steps,
progress=gr.Progress(track_tqdm=True),
):
if randomize_seed:
seed = random.randint(0, MAX_SEED)
# Parse resolution string into width and height
width, height = map(int, resolution.split('x'))
# Generate image with selected steps
image_quick = generate_image_with_steps(prompt, negative_prompt, seed, width, height, guidance_scale, num_inference_steps)
# Generate image with 50 steps for high quality
image_50_steps = generate_image_with_steps(prompt, negative_prompt, seed, width, height, guidance_scale, 50)
return image_quick, image_50_steps, seed
examples = [
"Astronaut in a jungle, cold color, muted colors, detailed, 8k",
"a painting of a virus monster playing guitar",
"a painting of a squirrel eating a burger",
]
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
with gr.Blocks(css=css) as demo:
with gr.Column(elem_id="col-container"):
gr.Markdown(" # Hyperparameters are all you need")
with gr.Row():
prompt = gr.Text(
label="Prompt",
show_label=False,
max_lines=1,
placeholder="Enter your prompt",
container=False,
)
run_button = gr.Button("Run", scale=0, variant="primary")
with gr.Row():
with gr.Column():
gr.Markdown("### Our fast inference Result")
result = gr.Image(label="Quick Result", show_label=False)
with gr.Column():
gr.Markdown("### Original 50 steps Result")
result_50_steps = gr.Image(label="50 Steps Result", show_label=False)
with gr.Accordion("Advanced Settings", open=False):
negative_prompt = gr.Text(
label="Negative prompt",
max_lines=1,
placeholder="Enter a negative prompt",
visible=False,
)
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
resolution = gr.Dropdown(
choices=[
"1024x1024",
"1216x832",
"832x1216"
],
value="1024x1024",
label="Resolution",
)
with gr.Row():
guidance_scale = gr.Slider(
label="Guidance scale",
minimum=0.0,
maximum=10.0,
step=0.1,
value=7.5, # Replace with defaults that work for your model
)
num_inference_steps = gr.Dropdown(
choices=[6, 8],
value=8,
label="Number of inference steps",
)
gr.Examples(examples=examples, inputs=[prompt])
gr.on(
triggers=[run_button.click, prompt.submit],
fn=infer,
inputs=[
prompt,
negative_prompt,
seed,
randomize_seed,
resolution,
guidance_scale,
num_inference_steps,
],
outputs=[result, result_50_steps, seed],
)
if __name__ == "__main__":
demo.launch()