| | import torch |
| | import os |
| | import PIL |
| |
|
| | from typing import List, Optional, Union |
| | from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput |
| | from diffusers.utils import logging |
| |
|
# Sub-folder name used by step_use_latents when building the output path
# os.path.join(folder_name, VECTOR_DATA_FOLDER, image_name) for vector data.
VECTOR_DATA_FOLDER = "vector_data"
# Used by step_use_latents both as the attribute name probed with hasattr() and
# as the stem of the saved "<name>.pt" file. It must stay equal to the literal
# attribute `vector_data` that step_use_latents assigns on the scheduler.
VECTOR_DATA_DICT = "vector_data"


# Module-level logger from diffusers' logging utilities.
logger = logging.get_logger(__name__)
| |
|
def get_ddpm_inversion_scheduler(
    scheduler,
    step_function,
    config,
    timesteps,
    save_timesteps,
    latents,
    x_ts,
    x_ts_c_hat,
    save_intermediate_results,
    pipe,
    x_0,
    v1s_images,
    v2s_images,
    deltas_images,
    v1_x0s,
    v2_x0s,
    deltas_x0s,
    folder_name,
    image_name,
    time_measure_n,
):
    """Patch ``scheduler`` in place so one ``step`` call runs inversion and editing together.

    The installed ``step`` splits the batch along dim 0:

    * row 0 is handled by ``step_save_latents`` (records predictions/residuals),
    * rows 1.. are handled by ``step_use_latents`` (reuses them),

    and returns a one-element tuple holding both results concatenated on dim 0.

    Every other argument is stored as an attribute on ``scheduler`` for those two
    functions to read. ``scheduler.x_0s`` is precomputed with ``create_xts`` using
    ``no_add_noise=True`` (zero noise).

    Returns:
        The same ``scheduler`` object, mutated.
    """

    def step(
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ):
        # Trailing positional arguments forwarded unchanged to both halves.
        forwarded = (eta, use_clipped_model_output, generator, variance_noise, return_dict)

        # Row 0: inversion branch.
        inversion_out = step_save_latents(
            scheduler,
            model_output[:1, :, :, :],
            timestep,
            sample[:1, :, :, :],
            *forwarded,
        )
        # Rows 1..: editing branch.
        editing_out = step_use_latents(
            scheduler,
            model_output[1:, :, :, :],
            timestep,
            sample[1:, :, :, :],
            *forwarded,
        )
        merged = torch.cat((inversion_out[0], editing_out[0]), dim=0)
        return (merged,)

    # Attributes assigned before x_0s is computed (insertion order is kept).
    early_attributes = {
        "step_function": step_function,
        "is_save": True,
        "_timesteps": timesteps,
        "_save_timesteps": save_timesteps or timesteps,
        "_config": config,
        "latents": latents,
        "x_ts": x_ts,
        "x_ts_c_hat": x_ts_c_hat,
        "step": step,
        "save_intermediate_results": save_intermediate_results,
        "pipe": pipe,
        "v1s_images": v1s_images,
        "v2s_images": v2s_images,
        "deltas_images": deltas_images,
        "v1_x0s": v1_x0s,
        "v2_x0s": v2_x0s,
        "deltas_x0s": deltas_x0s,
        "clean_step_run": False,
    }
    for attr_name, attr_value in early_attributes.items():
        setattr(scheduler, attr_name, attr_value)

    # Noise-free latents for every step. This uses pipe.scheduler, which may be the
    # very object being patched, so it stays after the attributes above.
    scheduler.x_0s = create_xts(
        noise_shift_delta=config.noise_shift_delta,
        noise_timesteps=config.noise_timesteps,
        clean_step_timestep=config.clean_step_timestep,
        generator=None,
        scheduler=pipe.scheduler,
        timesteps=timesteps,
        x_0=x_0,
        no_add_noise=True,
    )

    late_attributes = {
        "folder_name": folder_name,
        "image_name": image_name,
        "p_to_p": False,
        "p_to_p_replace": False,
        "time_measure_n": time_measure_n,
    }
    for attr_name, attr_value in late_attributes.items():
        setattr(scheduler, attr_name, attr_value)
    return scheduler
| |
|
def step_save_latents(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
):
    """Inversion half of the patched scheduler step.

    Calls ``self.step_function`` to get ``u_hat_t``, the model's prediction of the
    next latent. It then records:

    * ``u_hat_t`` in ``self.x_ts_c_hat``, and
    * the raw (un-clamped) residual ``z_t = x_ts[next] - u_hat_t`` in ``self.latents``,

    so that ``step_use_latents`` can reuse them. The returned sample is
    ``u_hat_t`` plus the residual after clamping with ``normalize``.

    Args:
        self: the patched scheduler. Reads ``_save_timesteps``, ``clean_step_run``,
            ``step_function``, ``x_ts`` and ``_config.max_norm_zs``.
        model_output, timestep, sample, eta, use_clipped_model_output, generator,
            variance_noise: forwarded to ``self.step_function``.
        return_dict: selects the wrapper type only. The sample value is the same
            either way.

    Returns:
        ``(prev_sample,)`` when ``return_dict`` is False, otherwise a
        ``DDIMSchedulerOutput`` whose ``prev_sample`` is that same tensor.
    """
    # During the extra "clean" step both indices point at the last stored entry.
    if self.clean_step_run:
        timestep_index = -1
        next_timestep_index = -1
    else:
        timestep_index = self._save_timesteps.index(timestep)
        next_timestep_index = timestep_index + 1

    u_hat_t = self.step_function(
        model_output=model_output,
        timestep=timestep,
        sample=sample,
        eta=eta,
        use_clipped_model_output=use_clipped_model_output,
        generator=generator,
        variance_noise=variance_noise,
        return_dict=False,
        scheduler=self,
    )

    x_t_minus_1 = self.x_ts[next_timestep_index]
    self.x_ts_c_hat.append(u_hat_t)

    # Store the raw residual. step_use_latents re-runs normalize on it to get the
    # clamping coefficient.
    z_t = x_t_minus_1 - u_hat_t
    self.latents.append(z_t)
    z_t, _ = normalize(z_t, timestep_index, self._config.max_norm_zs)

    x_t_minus_1_predicted = u_hat_t + z_t

    if not return_dict:
        return (x_t_minus_1_predicted,)

    # Previously this path returned the exact x_ts[next] instead, so the result
    # silently depended on return_dict whenever clamping was active.
    return DDIMSchedulerOutput(
        prev_sample=x_t_minus_1_predicted, pred_original_sample=None
    )
| |
|
def step_use_latents(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
):
    """Editing half of the patched scheduler step.

    Reuses what the inversion step recorded for this timestep and builds::

        x_{t-1} = coeff * x_ts[next] + w1 * v1 + w2 * v2

    Here ``coeff`` is the clamping coefficient that ``normalize`` returns for the
    stored residual ``self.latents[next]``. If it is 0, ``eta`` is forced to 0.
    ``w1``/``w2`` come from ``config.ws1``/``config.ws2``. ``x_t_c`` is row 0 of
    the stored inversion prediction ``self.x_ts_c_hat[next]``, broadcast to the
    batch. ``v1``/``v2`` depend on ``config.breakdown``:

    * ``"x_t_hat_c"`` / ``"x_t_hat_c_with_zeros"``: per the index arithmetic
      below, the batch is laid out as:
      - an optional reconstruction row 0 (absent when ``time_measure_n`` is truthy),
      - ``n`` x_t_hat_c rows,
      - ``n`` edit rows,
      - for ``"x_t_hat_c_with_zeros"`` without ``p_to_p``, ``n`` empty-prompt rows.

      ``x_t_hat_c`` is a zero tensor whose edit rows hold the predictions of the
      x_t_hat_c rows. Then ``v1 = prediction - x_t_hat_c`` and
      ``v2 = x_t_hat_c - coeff * x_t_c``.
    * ``"x_t_c_hat"``: not implemented.
    * anything else: ``v1 = prediction - coeff * x_t_c`` and ``v2 = 0``.

    Side effects:

    * With ``save_intermediate_results`` (and not ``p_to_p``), decodes and appends
      intermediate images.
    * With ``"x_t_hat_c_with_zeros"`` and not ``p_to_p``, records per-timestep
      tensors in ``self.vector_data`` and ``torch.save``-s them at the last timestep.

    Returns:
        ``(x_t_minus_1,)`` if ``return_dict`` is False, else a
        ``DDIMSchedulerOutput`` with ``prev_sample=x_t_minus_1``.

    Raises:
        NotImplementedError: when ``config.breakdown == "x_t_c_hat"``.
    """
    timestep_index = self._timesteps.index(timestep) if not self.clean_step_run else -1
    next_timestep_index = timestep_index + 1 if not self.clean_step_run else -1
    z_t = self.latents[next_timestep_index]

    # Only the clamping coefficient is needed here; the clamped tensor is discarded.
    _, normalize_coefficient = normalize(
        z_t[0] if self._config.breakdown == "x_t_hat_c_with_zeros" else z_t,
        timestep_index,
        self._config.max_norm_zs,
    )

    if normalize_coefficient == 0:
        eta = 0

    x_t_hat_c_hat = self.step_function(
        model_output=model_output,
        timestep=timestep,
        sample=sample,
        eta=eta,
        use_clipped_model_output=use_clipped_model_output,
        generator=generator,
        variance_noise=variance_noise,
        return_dict=False,
        scheduler=self,
    )

    w1 = self._config.ws1[timestep_index]
    w2 = self._config.ws2[timestep_index]

    # Exact inversion latent for the next step, broadcast over the editing batch.
    x_t_minus_1_exact = self.x_ts[next_timestep_index]
    x_t_minus_1_exact = x_t_minus_1_exact.expand_as(x_t_hat_c_hat)

    x_t_c_hat: torch.Tensor = self.x_ts_c_hat[next_timestep_index]
    if self._config.breakdown == "x_t_c_hat":
        raise NotImplementedError("breakdown x_t_c_hat not implemented yet")

    x_t_c = x_t_c_hat[0].expand_as(x_t_hat_c_hat)

    uses_x_t_hat_c = self._config.breakdown in ("x_t_hat_c", "x_t_hat_c_with_zeros")
    records_vectors = (
        self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p
    )

    if uses_x_t_hat_c:
        # Number of leading rows reserved for reconstruction (none when timing).
        zero_index_reconstruction = 1 if not self.time_measure_n else 0
        edit_prompts_num = (
            (model_output.size(0) - zero_index_reconstruction) // 3
            if records_vectors
            else (model_output.size(0) - zero_index_reconstruction) // 2
        )
        x_t_hat_c_rows = slice(
            zero_index_reconstruction, edit_prompts_num + zero_index_reconstruction
        )
        edit_rows = slice(
            edit_prompts_num + zero_index_reconstruction,
            (
                model_output.size(0)
                if self._config.breakdown == "x_t_hat_c"
                else zero_index_reconstruction + 2 * edit_prompts_num
            ),
        )
        x_t_hat_c = torch.zeros_like(x_t_hat_c_hat)
        x_t_hat_c[edit_rows] = x_t_hat_c_hat[x_t_hat_c_rows]
        v1 = x_t_hat_c_hat - x_t_hat_c
        v2 = x_t_hat_c - normalize_coefficient * x_t_c
        if records_vectors:
            path = os.path.join(
                self.folder_name,
                VECTOR_DATA_FOLDER,
                self.image_name,
            )
            if not hasattr(self, VECTOR_DATA_DICT):
                os.makedirs(path, exist_ok=True)
                self.vector_data = dict()

            # NOTE(review): assumes the stored inversion prediction has a second
            # row (the "0" prediction) — confirm against the caller.
            x_t_0 = x_t_c_hat[1]
            # The empty-prompt rows come after the optional reconstruction row,
            # the x_t_hat_c rows and the edit rows. Offsetting by
            # zero_index_reconstruction (not a literal 1) keeps this correct when
            # time_measure_n drops the reconstruction row.
            empty_prompt_rows = slice(
                zero_index_reconstruction + 2 * edit_prompts_num,
                zero_index_reconstruction + 3 * edit_prompts_num,
            )
            x_t_hat_0 = x_t_hat_c_hat[empty_prompt_rows]

            step_vectors = dict()
            self.vector_data[timestep.item()] = step_vectors
            step_vectors["x_t_hat_c"] = x_t_hat_c[edit_rows]
            step_vectors["x_t_hat_0"] = x_t_hat_0
            step_vectors["x_t_c"] = x_t_c[0].expand_as(x_t_hat_0)
            step_vectors["x_t_0"] = x_t_0.expand_as(x_t_hat_0)
            step_vectors["x_t_hat_c_hat"] = x_t_hat_c_hat[edit_rows]
            step_vectors["x_t_minus_1_noisy"] = x_t_minus_1_exact[0].expand_as(
                x_t_hat_0
            )
            step_vectors["x_t_minus_1_clean"] = self.x_0s[
                next_timestep_index
            ].expand_as(x_t_hat_0)
    else:
        v1 = x_t_hat_c_hat - normalize_coefficient * x_t_c
        v2 = 0

    if self.save_intermediate_results and not self.p_to_p:
        delta = v1 + v2
        v1_plus_x0 = self.x_0s[next_timestep_index] + v1
        v2_plus_x0 = self.x_0s[next_timestep_index] + v2
        delta_plus_x0 = self.x_0s[next_timestep_index] + delta

        v1_images = decode_latents(v1, self.pipe)
        self.v1s_images.append(v1_images)
        # With no breakdown, v2 is the scalar 0, so a 1x1 placeholder image is used.
        v2_images = (
            decode_latents(v2, self.pipe)
            if self._config.breakdown != "no_breakdown"
            else [PIL.Image.new("RGB", (1, 1))]
        )
        self.v2s_images.append(v2_images)
        delta_images = decode_latents(delta, self.pipe)
        self.deltas_images.append(delta_images)
        v1_plus_x0_images = decode_latents(v1_plus_x0, self.pipe)
        self.v1_x0s.append(v1_plus_x0_images)
        v2_plus_x0_images = (
            decode_latents(v2_plus_x0, self.pipe)
            if self._config.breakdown != "no_breakdown"
            else [PIL.Image.new("RGB", (1, 1))]
        )
        self.v2_x0s.append(v2_plus_x0_images)
        delta_plus_x0_images = decode_latents(delta_plus_x0, self.pipe)
        self.deltas_x0s.append(delta_plus_x0_images)

    x_t_minus_1 = normalize_coefficient * x_t_minus_1_exact + w1 * v1 + w2 * v2

    if uses_x_t_hat_c:
        # Keep the x_t_hat_c rows in lock-step with the edited rows.
        x_t_minus_1[x_t_hat_c_rows] = x_t_minus_1[edit_rows]
        if records_vectors:
            x_t_minus_1[empty_prompt_rows] = x_t_minus_1[edit_rows]
            step_vectors["x_t_minus_1_edited"] = x_t_minus_1[edit_rows]
            if timestep == self._timesteps[-1]:
                torch.save(
                    self.vector_data,
                    os.path.join(
                        path,
                        f"{VECTOR_DATA_DICT}.pt",
                    ),
                )

    # Row 0 (reconstruction, when present) is pinned to the exact inversion latent.
    if not self.time_measure_n:
        x_t_minus_1[0] = x_t_minus_1_exact[0]

    if not return_dict:
        return (x_t_minus_1,)

    return DDIMSchedulerOutput(
        prev_sample=x_t_minus_1,
        pred_original_sample=None,
    )
| |
|
def create_xts(
    noise_shift_delta,
    noise_timesteps,
    clean_step_timestep,
    generator,
    scheduler,
    timesteps,
    x_0,
    no_add_noise=False,
):
    """Build the list of latents ``x_t`` along the sampling trajectory.

    Args:
        noise_shift_delta: used only when ``noise_timesteps`` is None. Each entry
            of ``timesteps`` is shifted down by
            ``int(noise_shift_delta * (timesteps[0] - timesteps[1]))``.
        noise_timesteps: explicit noise levels, or None to derive them as above.
        clean_step_timestep: if > 0, one extra trailing copy of ``x_0`` is appended.
        generator: RNG passed to ``torch.randn`` (sampled on CPU).
        scheduler: object providing ``add_noise(samples, noise, timesteps)``.
        timesteps: the sampling timesteps.
        x_0: clean latent, expanded along dim 0 (expected to have batch size 1).
        no_add_noise: if True, zero noise is used instead of random noise.

    Returns:
        A list with one ``(1, ...)`` tensor per positive noise timestep (stopping
        at the first non-positive one). Then ``x_0`` repeated for the remaining
        timesteps, one more ``x_0``, and one extra ``x_0`` when
        ``clean_step_timestep > 0``.
    """
    if noise_timesteps is None:
        shift = int(noise_shift_delta * (timesteps[0] - timesteps[1]))
        noise_timesteps = [t - shift for t in timesteps]

    # Everything from the first non-positive noise level onward is treated as clean.
    cutoff = next(
        (idx for idx, level in enumerate(noise_timesteps) if level <= 0),
        len(noise_timesteps),
    )
    noise_timesteps = noise_timesteps[:cutoff]

    batch = x_0.expand(len(noise_timesteps), -1, -1, -1)
    if no_add_noise:
        noise = torch.zeros_like(batch)
    else:
        noise = torch.randn(batch.size(), generator=generator, device="cpu").to(
            x_0.device
        )
    noised = scheduler.add_noise(batch, noise, torch.IntTensor(noise_timesteps))

    result = [row.unsqueeze(dim=0) for row in noised]
    result.extend([x_0] * (len(timesteps) - cutoff))
    result.append(x_0)
    if clean_step_timestep > 0:
        result.append(x_0)
    return result
| |
|
def normalize(
    z_t,
    i,
    max_norm_zs,
):
    """Clamp the L2 norm of ``z_t`` to ``max_norm_zs[i]``.

    Args:
        z_t: tensor to clamp.
        i: index into ``max_norm_zs`` (negative indices allowed).
        max_norm_zs: per-step norm limits. A negative limit disables clamping.

    Returns:
        ``(z_t, 1)`` (the int 1) when clamping is disabled or the norm is already
        below the limit. Otherwise ``(z_t * scale, scale)`` with
        ``scale = limit / norm``.
    """
    limit = max_norm_zs[i]
    if limit < 0:
        return z_t, 1
    magnitude = torch.norm(z_t)
    within_limit = magnitude < limit
    if within_limit:
        return z_t, 1
    scale = limit / magnitude
    return z_t * scale, scale
| |
|
def decode_latents(latent, pipe):
    """Decode ``latent`` with ``pipe.vae`` and post-process the result into PIL images.

    The latent is divided by ``pipe.vae.config.scaling_factor`` before decoding.
    Element 0 of the ``return_dict=False`` result is then passed to
    ``pipe.image_processor.postprocess`` with ``output_type="pil"``.
    """
    unscaled = latent / pipe.vae.config.scaling_factor
    decoded = pipe.vae.decode(unscaled, return_dict=False)
    image_batch = decoded[0]
    return pipe.image_processor.postprocess(image_batch, output_type="pil")
| |
|
def deterministic_ddim_step(
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
    scheduler=None,
):
    """Compute one DDIM update without adding any random noise.

    The previous timestep is ``timestep - num_train_timesteps // num_inference_steps``.
    ``eta`` enters only through ``std_dev_t`` in the direction term.
    ``generator``, ``variance_noise`` and ``return_dict`` are accepted for
    signature compatibility but unused.

    Returns:
        The previous-sample tensor (never a tuple or output object).

    Raises:
        ValueError: if ``scheduler.num_inference_steps`` is None, or the
            prediction type is not ``epsilon``/``sample``/``v_prediction``.
    """
    if scheduler.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    step_size = scheduler.config.num_train_timesteps // scheduler.num_inference_steps
    prev_timestep = timestep - step_size

    alpha_prod_t = scheduler.alphas_cumprod[timestep]
    if prev_timestep >= 0:
        alpha_prod_t_prev = scheduler.alphas_cumprod[prev_timestep]
    else:
        alpha_prod_t_prev = scheduler.final_alpha_cumprod
    beta_prod_t = 1 - alpha_prod_t

    sqrt_alpha_t = alpha_prod_t ** 0.5
    sqrt_beta_t = beta_prod_t ** 0.5

    kind = scheduler.config.prediction_type
    if kind == "epsilon":
        pred_original_sample = (sample - sqrt_beta_t * model_output) / sqrt_alpha_t
        pred_epsilon = model_output
    elif kind == "sample":
        pred_original_sample = model_output
        pred_epsilon = (sample - sqrt_alpha_t * pred_original_sample) / sqrt_beta_t
    elif kind == "v_prediction":
        pred_original_sample = sqrt_alpha_t * sample - sqrt_beta_t * model_output
        pred_epsilon = sqrt_alpha_t * model_output + sqrt_beta_t * sample
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # Optional thresholding / clipping of the x_0 estimate.
    if scheduler.config.thresholding:
        pred_original_sample = scheduler._threshold_sample(pred_original_sample)
    elif scheduler.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -scheduler.config.clip_sample_range,
            scheduler.config.clip_sample_range,
        )

    variance = scheduler._get_variance(timestep, prev_timestep)
    std_dev_t = eta * variance ** 0.5

    if use_clipped_model_output:
        # Re-derive epsilon from the (possibly clipped) x_0 estimate.
        pred_epsilon = (sample - sqrt_alpha_t * pred_original_sample) / sqrt_beta_t

    direction_scale = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5
    pred_sample_direction = direction_scale * pred_epsilon

    return alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
| |
|
| |
|
def deterministic_euler_step(
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    eta,
    use_clipped_model_output,
    generator,
    variance_noise,
    return_dict,
    scheduler,
):
    """Take one Euler-ancestral step without re-injecting noise.

    The sample moves from ``sigmas[step_index]`` to ``sigma_down``, where
    ``sigma_up = sqrt(sigma_to^2 * (sigma^2 - sigma_to^2) / sigma^2)`` and
    ``sigma_down = sqrt(sigma_to^2 - sigma_up^2)``. The ``sigma_up`` noise term
    of the ancestral sampler is omitted. ``eta``, ``use_clipped_model_output``,
    ``generator``, ``variance_noise`` and ``return_dict`` are unused.
    ``scheduler._step_index`` is incremented.

    Returns:
        torch.FloatTensor: the previous sample cast to ``model_output.dtype``.
        Always a bare tensor, regardless of ``return_dict``.

    Raises:
        ValueError: for integer timesteps or an unknown prediction type.
        NotImplementedError: for ``prediction_type == "sample"``.
    """
    if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
        raise ValueError(
            (
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            ),
        )

    if scheduler.step_index is None:
        scheduler._init_step_index(timestep)

    sigma = scheduler.sigmas[scheduler.step_index]

    # Work in float32 to avoid precision issues.
    sample = sample.to(torch.float32)

    kind = scheduler.config.prediction_type
    if kind == "epsilon":
        pred_original_sample = sample - sigma * model_output
    elif kind == "v_prediction":
        pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
            sample / (sigma**2 + 1)
        )
    elif kind == "sample":
        raise NotImplementedError("prediction_type not implemented yet: sample")
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    sigma_to = scheduler.sigmas[scheduler.step_index + 1]
    sigma_up = (sigma_to**2 * (sigma**2 - sigma_to**2) / sigma**2) ** 0.5
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

    slope = (sample - pred_original_sample) / sigma
    prev_sample = (sample + slope * (sigma_down - sigma)).to(model_output.dtype)

    scheduler._step_index += 1
    return prev_sample
| |
|
| |
|
def deterministic_non_ancestral_euler_step(
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    s_churn: float = 0.0,
    s_tmin: float = 0.0,
    s_tmax: float = float("inf"),
    s_noise: float = 1.0,
    generator: Optional[torch.Generator] = None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
    scheduler=None,
):
    """Take one (non-ancestral) Euler step with no noise injection.

    ``gamma`` is ``min(s_churn / (len(sigmas) - 1), sqrt(2) - 1)`` when
    ``s_tmin <= sigma <= s_tmax``, else 0, and ``sigma_hat = sigma * (gamma + 1)``.
    Unlike a stochastic Euler step, no churn noise is added to ``sample`` even
    when ``gamma > 0``. ``eta``, ``use_clipped_model_output``, ``s_noise``,
    ``generator``, ``variance_noise`` and ``return_dict`` are unused.
    ``scheduler._step_index`` is incremented.

    Returns:
        torch.FloatTensor: the previous sample cast to ``model_output.dtype``.
        Always a bare tensor, regardless of ``return_dict``.

    Raises:
        ValueError: for integer timesteps or an unknown prediction type.
    """
    if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
        raise ValueError(
            (
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            ),
        )

    if not scheduler.is_scale_input_called:
        logger.warning(
            "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
            "See `StableDiffusionPipeline` for a usage example."
        )

    if scheduler.step_index is None:
        scheduler._init_step_index(timestep)

    # Work in float32 to avoid precision issues.
    sample = sample.to(torch.float32)

    sigma = scheduler.sigmas[scheduler.step_index]

    in_churn_window = s_tmin <= sigma <= s_tmax
    gamma = (
        min(s_churn / (len(scheduler.sigmas) - 1), 2**0.5 - 1)
        if in_churn_window
        else 0.0
    )
    sigma_hat = sigma * (gamma + 1)

    kind = scheduler.config.prediction_type
    if kind in ("original_sample", "sample"):
        pred_original_sample = model_output
    elif kind == "epsilon":
        pred_original_sample = sample - sigma_hat * model_output
    elif kind == "v_prediction":
        pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
            sample / (sigma**2 + 1)
        )
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    slope = (sample - pred_original_sample) / sigma_hat
    step_length = scheduler.sigmas[scheduler.step_index + 1] - sigma_hat
    prev_sample = (sample + slope * step_length).to(model_output.dtype)

    scheduler._step_index += 1
    return prev_sample
| |
|
| |
|
def deterministic_ddpm_step(
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    eta,
    use_clipped_model_output,
    generator,
    variance_noise,
    return_dict,
    scheduler,
):
    """Compute the DDPM posterior mean for the previous timestep, with no added noise.

    The previous timestep comes from ``scheduler.previous_timestep``. When the
    model output has twice the sample's channel count and ``variance_type`` is
    ``learned``/``learned_range``, only the first half (the mean prediction) is
    used. ``eta``, ``use_clipped_model_output``, ``generator``,
    ``variance_noise`` and ``return_dict`` are unused.

    Returns:
        The previous-sample tensor ``c0 * x0_pred + ct * sample``. Always a bare
        tensor, regardless of ``return_dict``.

    Raises:
        ValueError: for an unknown prediction type.
    """
    t = timestep
    prev_t = scheduler.previous_timestep(t)

    learned_variance = model_output.shape[1] == sample.shape[1] * 2 and (
        scheduler.variance_type in ["learned", "learned_range"]
    )
    if learned_variance:
        # Keep the mean half; the predicted variance is not needed here.
        model_output, _ = torch.split(model_output, sample.shape[1], dim=1)

    a_t = scheduler.alphas_cumprod[t]
    a_prev = scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else scheduler.one
    b_t = 1 - a_t
    b_prev = 1 - a_prev
    step_alpha = a_t / a_prev
    step_beta = 1 - step_alpha

    kind = scheduler.config.prediction_type
    if kind == "epsilon":
        pred_original_sample = (sample - b_t ** 0.5 * model_output) / a_t ** 0.5
    elif kind == "sample":
        pred_original_sample = model_output
    elif kind == "v_prediction":
        pred_original_sample = (a_t**0.5) * sample - (b_t**0.5) * model_output
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction` for the DDPMScheduler."
        )

    # Optional thresholding / clipping of the x_0 estimate.
    if scheduler.config.thresholding:
        pred_original_sample = scheduler._threshold_sample(pred_original_sample)
    elif scheduler.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -scheduler.config.clip_sample_range, scheduler.config.clip_sample_range
        )

    # Posterior-mean coefficients for x_0 and x_t.
    x0_coeff = (a_prev ** 0.5 * step_beta) / b_t
    xt_coeff = step_alpha ** 0.5 * b_prev / b_t

    return x0_coeff * pred_original_sample + xt_coeff * sample
| |
|