From 81331f3b7d9d05e27a10df181819ffbd6ef6e135 Mon Sep 17 00:00:00 2001 From: Bagheera <59658056+bghira@users.noreply.github.com> Date: Tue, 19 Sep 2023 08:57:44 -0700 Subject: [PATCH 1/4] Add x-prediction / prediction_type=sample support for SDXL fine-tuning (#5095) Co-authored-by: bghira Co-authored-by: Sayak Paul --- examples/text_to_image/train_text_to_image_sdxl.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 299d0f0d75..22486298c9 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -998,6 +998,11 @@ def main(args): target = noise elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(model_input, noise, timesteps) + elif noise_scheduler.config.prediction_type == "sample": + # We set the target to latents here, but the model_pred will return the noise sample prediction. + target = model_input + # We will have to subtract the noise residual from the prediction to get the target sample. + model_pred = model_pred - noise else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") From 74e43a4fbdff639dc3cc1a286a74b160b17d480b Mon Sep 17 00:00:00 2001 From: Bagheera <59658056+bghira@users.noreply.github.com> Date: Tue, 19 Sep 2023 09:31:27 -0700 Subject: [PATCH 2/4] Resolve v_prediction issue for min-SNR gamma weighted loss function (#5096) * Resolve v_prediction issue for min-SNR gamma weighted loss function * Combine MSE loss calculation of epsilon and velocity, with a note about the application of the epsilon code to sample prediction * style --------- Co-authored-by: bghira Co-authored-by: Sayak Paul --- .../text_to_image/train_text_to_image_sdxl.py | 31 ++++++------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 22486298c9..d37621e50f 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -332,15 +332,6 @@ def parse_args(input_args=None): help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " "More details here: https://arxiv.org/abs/2303.09556.", ) - parser.add_argument( - "--force_snr_gamma", - action="store_true", - help=( - "When using SNR gamma with rescaled betas for zero terminal SNR, a divide-by-zero error can cause NaN" - " condition when computing the SNR with a sigma value of zero. This parameter overrides the check," - " allowing the use of SNR gamma with a terminal SNR model. Use with caution, and closely monitor results." - ), - ) parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") parser.add_argument( "--allow_tf32", @@ -554,18 +545,6 @@ def main(args): # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") # Check for terminal SNR in combination with SNR Gamma - if ( - args.snr_gamma - and not args.force_snr_gamma - and ( - hasattr(noise_scheduler.config, "rescale_betas_zero_snr") and noise_scheduler.config.rescale_betas_zero_snr - ) - ): - raise ValueError( - f"The selected noise scheduler for the model {args.pretrained_model_name_or_path} uses rescaled betas for zero SNR.\n" - "When this configuration is present, the parameter --snr_gamma may not be used without parameter --force_snr_gamma.\n" - "This is due to a mathematical incompatibility between our current SNR gamma implementation, and a sigma value of zero." - ) text_encoder_one = text_encoder_cls_one.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) @@ -1013,9 +992,17 @@ def main(args): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(timesteps) - mse_loss_weights = ( + base_weight = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. From 8263cf00f832399bca215e29fa7572e0b0bde4da Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Tue, 19 Sep 2023 11:21:49 -1000 Subject: [PATCH 3/4] refactor DPMSolverMultistepScheduler using sigmas (#4986) --------- Co-authored-by: yiyixuxu Co-authored-by: Patrick von Platen --- .../schedulers/scheduling_deis_multistep.py | 307 +++++++++++---- .../scheduling_dpmsolver_multistep.py | 303 ++++++++++----- .../scheduling_dpmsolver_multistep_inverse.py | 356 +++++++++++++----- .../scheduling_dpmsolver_singlestep.py | 310 +++++++++++---- .../schedulers/scheduling_unipc_multistep.py | 246 ++++++++---- tests/schedulers/test_scheduler_deis.py | 1 + tests/schedulers/test_scheduler_dpm_multi.py | 6 +- .../test_scheduler_dpm_multi_inverse.py | 3 +- tests/schedulers/test_scheduler_dpm_single.py | 31 ++ tests/schedulers/test_scheduler_unipc.py | 9 +- 10 files changed, 1165 insertions(+), 407 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 95d809575d..c7a94bce88 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -22,6 +22,7 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import deprecate from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput @@ -186,6 +187,14 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): self.timesteps = torch.from_numpy(timesteps) self.model_outputs = [None] * solver_order self.lower_order_nums = 0 + self._step_index = None + + @property + def step_index(self): + """ + The index counter for current timestep. It will increae 1 after each scheduler step. + """ + return self._step_index def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ @@ -225,17 +234,16 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) if self.config.use_karras_sigmas: log_sigmas = np.log(sigmas) + sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - timesteps = np.flip(timesteps).copy().astype(np.int64) - - self.sigmas = torch.from_numpy(sigmas) - - # when num_inference_steps == num_train_timesteps, we can end up with - # duplicates in timesteps. - _, unique_indices = np.unique(timesteps, return_index=True) - timesteps = timesteps[np.sort(unique_indices)] + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas).to(device=device) self.timesteps = torch.from_numpy(timesteps).to(device) self.num_inference_steps = len(timesteps) @@ -245,6 +253,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ] * self.config.solver_order self.lower_order_nums = 0 + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: """ @@ -280,8 +291,57 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): return sample + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + def convert_model_output( - self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + self, + model_output: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: """ Convert the model output to the corresponding type the DEIS algorithm needs. @@ -298,13 +358,26 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: The converted model output. """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) if self.config.prediction_type == "epsilon": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": x0_pred = model_output elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = alpha_t * sample - sigma_t * model_output else: raise ValueError( @@ -316,7 +389,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): x0_pred = self._threshold_sample(x0_pred) if self.config.algorithm_type == "deis": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] return (sample - alpha_t * x0_pred) / sigma_t else: raise NotImplementedError("only support log-rho multistep deis now") @@ -324,9 +396,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): def deis_first_order_update( self, model_output: torch.FloatTensor, - timestep: int, - prev_timestep: int, - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the first-order DEIS (equivalent to DDIM). @@ -345,9 +417,33 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: The sample tensor at the previous timestep. """ - lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] - alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] - sigma_t, _ = self.sigma_t[prev_timestep], self.sigma_t[timestep] + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + h = lambda_t - lambda_s if self.config.algorithm_type == "deis": x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output @@ -358,9 +454,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): def multistep_deis_second_order_update( self, model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the second-order multistep DEIS. @@ -368,10 +464,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): Args: model_output_list (`List[torch.FloatTensor]`): The direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): - The current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): - The previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. @@ -379,10 +471,38 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: The sample tensor at the previous timestep. """ - t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + m0, m1 = model_output_list[-1], model_output_list[-2] - alpha_t, alpha_s0, alpha_s1 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1] - sigma_t, sigma_s0, sigma_s1 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1] rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1 @@ -403,9 +523,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): def multistep_deis_third_order_update( self, model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the third-order multistep DEIS. @@ -413,10 +533,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): Args: model_output_list (`List[torch.FloatTensor]`): The direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): - The current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): - The previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by diffusion process. @@ -424,15 +540,47 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: The sample tensor at the previous timestep. """ - t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] + + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + self.sigmas[self.step_index - 2], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] - alpha_t, alpha_s0, alpha_s1, alpha_s2 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1], self.alpha_t[s2] - sigma_t, sigma_s0, sigma_s1, simga_s2 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1], self.sigma_t[s2] + rho_t, rho_s0, rho_s1, rho_s2 = ( sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1, - simga_s2 / alpha_s2, + sigma_s2 / alpha_s2, ) if self.config.algorithm_type == "deis": @@ -460,6 +608,25 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): else: raise NotImplementedError("only support log-rho multistep deis now") + def _init_step_index(self, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + index_candidates = (self.timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + self._step_index = step_index + def step( self, model_output: torch.FloatTensor, @@ -492,42 +659,34 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero() - if len(step_index) == 0: - step_index = len(self.timesteps) - 1 - else: - step_index = step_index.item() - prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] + if self.step_index is None: + self._init_step_index(timestep) + lower_order_final = ( - (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 + (self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 ) lower_order_second = ( - (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 ) - model_output = self.convert_model_output(model_output, timestep, sample) + model_output = self.convert_model_output(model_output, sample=sample) for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: - prev_sample = self.deis_first_order_update(model_output, timestep, prev_timestep, sample) + prev_sample = self.deis_first_order_update(model_output, sample=sample) elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: - timestep_list = [self.timesteps[step_index - 1], timestep] - prev_sample = self.multistep_deis_second_order_update( - self.model_outputs, timestep_list, prev_timestep, sample - ) + prev_sample = self.multistep_deis_second_order_update(self.model_outputs, sample=sample) else: - timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] - prev_sample = self.multistep_deis_third_order_update( - self.model_outputs, timestep_list, prev_timestep, sample - ) + prev_sample = self.multistep_deis_third_order_update(self.model_outputs, sample=sample) if self.lower_order_nums < self.config.solver_order: self.lower_order_nums += 1 + # upon completion increase step index by one + self._step_index += 1 + if not return_dict: return (prev_sample,) @@ -548,28 +707,30 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): """ return sample - # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.IntTensor, + timesteps: torch.FloatTensor, ) -> torch.FloatTensor: - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) - sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index babba2206d..264ee268ae 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -21,6 +21,7 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import deprecate from ..utils.torch_utils import randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput @@ -203,6 +204,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): self.timesteps = torch.from_numpy(timesteps) self.model_outputs = [None] * solver_order self.lower_order_nums = 0 + self._step_index = None + + @property + def step_index(self): + """ + The index counter for current timestep. It will increae 1 after each scheduler step. + """ + return self._step_index def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): """ @@ -242,19 +251,19 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + if self.config.use_karras_sigmas: - log_sigmas = np.log(sigmas) + sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - timesteps = np.flip(timesteps).copy().astype(np.int64) - - self.sigmas = torch.from_numpy(sigmas) - - # when num_inference_steps == num_train_timesteps, we can end up with - # duplicates in timesteps. - _, unique_indices = np.unique(timesteps, return_index=True) - timesteps = timesteps[np.sort(unique_indices)] + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas).to(device=device) self.timesteps = torch.from_numpy(timesteps).to(device) self.num_inference_steps = len(timesteps) @@ -264,6 +273,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ] * self.config.solver_order self.lower_order_nums = 0 + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: """ @@ -323,6 +335,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): t = t.reshape(sigma.shape) return t + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: """Constructs the noise schedule of Karras et al. (2022).""" @@ -338,7 +356,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): return sigmas def convert_model_output( - self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + self, + model_output: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: """ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is @@ -355,8 +377,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): Args: model_output (`torch.FloatTensor`): The direct output from the learned diffusion model. - timestep (`int`): - The current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. @@ -364,6 +384,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: The converted model output. """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) # DPM-Solver++ needs to solve an integral of the data prediction model. if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: @@ -371,12 +403,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): # DPM-Solver and DPM-Solver++ only need the "mean" output. if self.config.variance_type in ["learned", "learned_range"]: model_output = model_output[:, :3] - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": x0_pred = model_output elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) x0_pred = alpha_t * sample - sigma_t * model_output else: raise ValueError( @@ -398,10 +432,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): else: epsilon = model_output elif self.config.prediction_type == "sample": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) epsilon = (sample - alpha_t * model_output) / sigma_t elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) epsilon = alpha_t * model_output + sigma_t * sample else: raise ValueError( @@ -410,7 +446,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ) if self.config.thresholding: - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) x0_pred = (sample - sigma_t * epsilon) / alpha_t x0_pred = self._threshold_sample(x0_pred) epsilon = (sample - alpha_t * x0_pred) / sigma_t @@ -420,10 +457,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): def dpm_solver_first_order_update( self, model_output: torch.FloatTensor, - timestep: int, - prev_timestep: int, - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, noise: Optional[torch.FloatTensor] = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the first-order DPMSolver (equivalent to DDIM). @@ -431,10 +468,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): Args: model_output (`torch.FloatTensor`): The direct output from the learned diffusion model. - timestep (`int`): - The current discrete timestep in the diffusion chain. - prev_timestep (`int`): - The previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. @@ -442,9 +475,33 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: The sample tensor at the previous timestep. """ - lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] - alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] - sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + h = lambda_t - lambda_s if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output @@ -469,10 +526,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): def multistep_dpm_solver_second_order_update( self, model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, noise: Optional[torch.FloatTensor] = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the second-order multistep DPMSolver. @@ -480,10 +537,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): Args: model_output_list (`List[torch.FloatTensor]`): The direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): - The current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): - The previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. @@ -491,11 +544,43 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: The sample tensor at the previous timestep. """ - t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + m0, m1 = model_output_list[-1], model_output_list[-2] - lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] - alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] - sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 r0 = h_0 / h D0, D1 = m0, (1.0 / r0) * (m0 - m1) @@ -564,9 +649,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): def multistep_dpm_solver_third_order_update( self, model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the third-order multistep DPMSolver. @@ -574,10 +659,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): Args: model_output_list (`List[torch.FloatTensor]`): The direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): - The current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): - The previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by diffusion process. @@ -585,16 +666,47 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: The sample tensor at the previous timestep. """ - t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] - m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] - lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( - self.lambda_t[t], - self.lambda_t[s0], - self.lambda_t[s1], - self.lambda_t[s2], + + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + self.sigmas[self.step_index - 2], ) - alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] - sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 r0, r1 = h_0 / h, h_1 / h D0 = m0 @@ -619,6 +731,25 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ) return x_t + def _init_step_index(self, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + index_candidates = (self.timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + self._step_index = step_index + def step( self, model_output: torch.FloatTensor, @@ -654,22 +785,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero() - if len(step_index) == 0: - step_index = len(self.timesteps) - 1 - else: - step_index = step_index.item() - prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] + if self.step_index is None: + self._init_step_index(timestep) + lower_order_final = ( - (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 + (self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 ) lower_order_second = ( - (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 ) - model_output = self.convert_model_output(model_output, timestep, sample) + model_output = self.convert_model_output(model_output, sample=sample) for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output @@ -682,23 +808,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): noise = None if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: - prev_sample = self.dpm_solver_first_order_update( - model_output, timestep, prev_timestep, sample, noise=noise - ) + prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: - timestep_list = [self.timesteps[step_index - 1], timestep] - prev_sample = self.multistep_dpm_solver_second_order_update( - self.model_outputs, timestep_list, prev_timestep, sample, noise=noise - ) + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) else: - timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] - prev_sample = self.multistep_dpm_solver_third_order_update( - self.model_outputs, timestep_list, prev_timestep, sample - ) + prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) if self.lower_order_nums < self.config.solver_order: self.lower_order_nums += 1 + # upon completion increase step index by one + self._step_index += 1 + if not return_dict: return (prev_sample,) @@ -719,28 +840,30 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): """ return sample - # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.IntTensor, + timesteps: torch.FloatTensor, ) -> torch.FloatTensor: - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) - sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 33a2637d00..7c740234fa 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -21,6 +21,7 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import deprecate from ..utils.torch_utils import randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput @@ -203,8 +204,16 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): self.timesteps = torch.from_numpy(timesteps) self.model_outputs = [None] * solver_order self.lower_order_nums = 0 + self._step_index = None self.use_karras_sigmas = use_karras_sigmas + @property + def step_index(self): + """ + The index counter for current timestep. It will increae 1 after each scheduler step. + """ + return self._step_index + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -244,11 +253,19 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + if self.config.use_karras_sigmas: - log_sigmas = np.log(sigmas) sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = timesteps.copy().astype(np.int64) + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigma_max = ( + (1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep] + ) ** 0.5 + sigmas = np.concatenate([sigmas, [sigma_max]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) @@ -266,6 +283,9 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ] * self.config.solver_order self.lower_order_nums = 0 + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: """ @@ -325,6 +345,13 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): t = t.reshape(sigma.shape) return t + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: """Constructs the noise schedule of Karras et al. (2022).""" @@ -341,7 +368,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output def convert_model_output( - self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + self, + model_output: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: """ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is @@ -358,8 +389,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): Args: model_output (`torch.FloatTensor`): The direct output from the learned diffusion model. - timestep (`int`): - The current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. @@ -367,6 +396,18 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: The converted model output. """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) # DPM-Solver++ needs to solve an integral of the data prediction model. if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: @@ -374,12 +415,14 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): # DPM-Solver and DPM-Solver++ only need the "mean" output. if self.config.variance_type in ["learned", "learned_range"]: model_output = model_output[:, :3] - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": x0_pred = model_output elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) x0_pred = alpha_t * sample - sigma_t * model_output else: raise ValueError( @@ -401,10 +444,12 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): else: epsilon = model_output elif self.config.prediction_type == "sample": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) epsilon = (sample - alpha_t * model_output) / sigma_t elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) epsilon = alpha_t * model_output + sigma_t * sample else: raise ValueError( @@ -413,20 +458,22 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ) if self.config.thresholding: - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) x0_pred = (sample - sigma_t * epsilon) / alpha_t x0_pred = self._threshold_sample(x0_pred) epsilon = (sample - alpha_t * x0_pred) / sigma_t return epsilon + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update def dpm_solver_first_order_update( self, model_output: torch.FloatTensor, - timestep: int, - prev_timestep: int, - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, noise: Optional[torch.FloatTensor] = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the first-order DPMSolver (equivalent to DDIM). @@ -434,10 +481,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): Args: model_output (`torch.FloatTensor`): The direct output from the learned diffusion model. - timestep (`int`): - The current discrete timestep in the diffusion chain. - prev_timestep (`int`): - The previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. @@ -445,27 +488,62 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: The sample tensor at the previous timestep. """ - lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] - alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] - sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + h = lambda_t - lambda_s if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output elif self.config.algorithm_type == "dpmsolver": x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output - elif "sde" in self.config.algorithm_type: - raise NotImplementedError( - f"Inversion step is not yet implemented for algorithm type {self.config.algorithm_type}." + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ( + (alpha_t / alpha_s) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise ) return x_t + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update def multistep_dpm_solver_second_order_update( self, model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, noise: Optional[torch.FloatTensor] = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the second-order multistep DPMSolver. @@ -473,10 +551,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): Args: model_output_list (`List[torch.FloatTensor]`): The direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): - The current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): - The previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. @@ -484,11 +558,43 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: The sample tensor at the previous timestep. """ - t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + m0, m1 = model_output_list[-1], model_output_list[-2] - lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] - alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] - sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 r0 = h_0 / h D0, D1 = m0, (1.0 / r0) * (m0 - m1) @@ -520,19 +626,47 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): - (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 ) - elif "sde" in self.config.algorithm_type: - raise NotImplementedError( - f"Inversion step is not yet implemented for algorithm type {self.config.algorithm_type}." - ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) return x_t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update def multistep_dpm_solver_third_order_update( self, model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the third-order multistep DPMSolver. @@ -540,10 +674,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): Args: model_output_list (`List[torch.FloatTensor]`): The direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): - The current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): - The previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by diffusion process. @@ -551,16 +681,47 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: The sample tensor at the previous timestep. """ - t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] - m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] - lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( - self.lambda_t[t], - self.lambda_t[s0], - self.lambda_t[s1], - self.lambda_t[s2], + + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + self.sigmas[self.step_index - 2], ) - alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] - sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 r0, r1 = h_0 / h, h_1 / h D0 = m0 @@ -585,6 +746,27 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ) return x_t + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + index_candidates = (self.timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + self._step_index = step_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step def step( self, model_output: torch.FloatTensor, @@ -604,6 +786,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): The current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. return_dict (`bool`): Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. @@ -618,24 +802,17 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero() - if len(step_index) == 0: - step_index = len(self.timesteps) - 1 - else: - step_index = step_index.item() - prev_timestep = ( - self.noisiest_timestep if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] - ) + if self.step_index is None: + self._init_step_index(timestep) + lower_order_final = ( - (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 + (self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 ) lower_order_second = ( - (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 ) - model_output = self.convert_model_output(model_output, timestep, sample) + model_output = self.convert_model_output(model_output, sample=sample) for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output @@ -648,23 +825,18 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): noise = None if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: - prev_sample = self.dpm_solver_first_order_update( - model_output, timestep, prev_timestep, sample, noise=noise - ) + prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: - timestep_list = [self.timesteps[step_index - 1], timestep] - prev_sample = self.multistep_dpm_solver_second_order_update( - self.model_outputs, timestep_list, prev_timestep, sample, noise=noise - ) + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) else: - timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] - prev_sample = self.multistep_dpm_solver_third_order_update( - self.model_outputs, timestep_list, prev_timestep, sample - ) + prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) if self.lower_order_nums < self.config.solver_order: self.lower_order_nums += 1 + # upon completion increase step index by one + self._step_index += 1 + if not return_dict: return (prev_sample,) @@ -686,28 +858,30 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): """ return sample - # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.IntTensor, + timesteps: torch.FloatTensor, ) -> torch.FloatTensor: - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) - sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 060ec363e8..10f7ab34e0 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -21,7 +21,7 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import logging +from ..utils import deprecate, logging from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput @@ -197,6 +197,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): self.model_outputs = [None] * solver_order self.sample = None self.order_list = self.get_order_list(num_train_timesteps) + self._step_index = None def get_order_list(self, num_inference_steps: int) -> List[int]: """ @@ -232,6 +233,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): orders = [1] * steps return orders + @property + def step_index(self): + """ + The index counter for current timestep. It will increae 1 after each scheduler step. + """ + return self._step_index + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -256,11 +264,16 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) if self.config.use_karras_sigmas: log_sigmas = np.log(sigmas) + sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - timesteps = np.flip(timesteps).copy().astype(np.int64) + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) - self.sigmas = torch.from_numpy(sigmas) + self.sigmas = torch.from_numpy(sigmas).to(device=device) self.timesteps = torch.from_numpy(timesteps).to(device) self.model_outputs = [None] * self.config.solver_order @@ -274,6 +287,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): self.order_list = self.get_order_list(num_inference_steps) + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: """ @@ -333,6 +349,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): t = t.reshape(sigma.shape) return t + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: """Constructs the noise schedule of Karras et al. (2022).""" @@ -348,7 +371,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): return sigmas def convert_model_output( - self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + self, + model_output: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: """ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is @@ -365,8 +392,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): Args: model_output (`torch.FloatTensor`): The direct output from the learned diffusion model. - timestep (`int`): - The current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. @@ -374,18 +399,32 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: The converted model output. """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) # DPM-Solver++ needs to solve an integral of the data prediction model. if self.config.algorithm_type == "dpmsolver++": if self.config.prediction_type == "epsilon": # DPM-Solver and DPM-Solver++ only need the "mean" output. if self.config.variance_type in ["learned_range"]: model_output = model_output[:, :3] - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": x0_pred = model_output elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) x0_pred = alpha_t * sample - sigma_t * model_output else: raise ValueError( @@ -405,11 +444,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): model_output = model_output[:, :3] return model_output elif self.config.prediction_type == "sample": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) epsilon = (sample - alpha_t * model_output) / sigma_t return epsilon elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) epsilon = alpha_t * model_output + sigma_t * sample return epsilon else: @@ -421,9 +462,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): def dpm_solver_first_order_update( self, model_output: torch.FloatTensor, - timestep: int, - prev_timestep: int, - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the first-order DPMSolver (equivalent to DDIM). @@ -442,9 +483,31 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: The sample tensor at the previous timestep. """ - lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] - alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] - sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) h = lambda_t - lambda_s if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output @@ -455,9 +518,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): def singlestep_dpm_solver_second_order_update( self, model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the second-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the @@ -477,11 +540,42 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: The sample tensor at the previous timestep. """ - t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + m0, m1 = model_output_list[-1], model_output_list[-2] - lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] - alpha_t, alpha_s1 = self.alpha_t[t], self.alpha_t[s1] - sigma_t, sigma_s1 = self.sigma_t[t], self.sigma_t[s1] + h, h_0 = lambda_t - lambda_s1, lambda_s0 - lambda_s1 r0 = h_0 / h D0, D1 = m1, (1.0 / r0) * (m0 - m1) @@ -518,9 +612,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): def singlestep_dpm_solver_third_order_update( self, model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, - sample: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the third-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the @@ -540,16 +634,47 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: The sample tensor at the previous timestep. """ - t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] - m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] - lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( - self.lambda_t[t], - self.lambda_t[s0], - self.lambda_t[s1], - self.lambda_t[s2], + + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + self.sigmas[self.step_index - 2], ) - alpha_t, alpha_s2 = self.alpha_t[t], self.alpha_t[s2] - sigma_t, sigma_s2 = self.sigma_t[t], self.sigma_t[s2] + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2 r0, r1 = h_0 / h, h_1 / h D0 = m2 @@ -591,10 +716,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): def singlestep_dpm_solver_update( self, model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, - sample: torch.FloatTensor, - order: int, + *args, + sample: torch.FloatTensor = None, + order: int = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the singlestep DPMSolver. @@ -615,19 +740,60 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: The sample tensor at the previous timestep. """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError(" missing `order` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + if order == 1: - return self.dpm_solver_first_order_update(model_output_list[-1], timestep_list[-1], prev_timestep, sample) + return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample) elif order == 2: - return self.singlestep_dpm_solver_second_order_update( - model_output_list, timestep_list, prev_timestep, sample - ) + return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample) elif order == 3: - return self.singlestep_dpm_solver_third_order_update( - model_output_list, timestep_list, prev_timestep, sample - ) + return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample) else: raise ValueError(f"Order must be 1, 2, 3, got {order}") + def _init_step_index(self, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + index_candidates = (self.timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + self._step_index = step_index + def step( self, model_output: torch.FloatTensor, @@ -660,21 +826,15 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero() - if len(step_index) == 0: - step_index = len(self.timesteps) - 1 - else: - step_index = step_index.item() - prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] + if self.step_index is None: + self._init_step_index(timestep) - model_output = self.convert_model_output(model_output, timestep, sample) + model_output = self.convert_model_output(model_output, sample=sample) for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output - order = self.order_list[step_index] + order = self.order_list[self.step_index] # For img2img denoising might start with order>1 which is not possible # In this case make sure that the first two steps are both order=1 @@ -685,10 +845,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): if order == 1: self.sample = sample - timestep_list = [self.timesteps[step_index - i] for i in range(order - 1, 0, -1)] + [timestep] - prev_sample = self.singlestep_dpm_solver_update( - self.model_outputs, timestep_list, prev_timestep, self.sample, order - ) + prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order) + + # upon completion increase step index by one + self._step_index += 1 if not return_dict: return (prev_sample,) @@ -710,28 +870,30 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): """ return sample - # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.IntTensor, + timesteps: torch.FloatTensor, ) -> torch.FloatTensor: - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) - sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 1525845db3..2dcca2ecae 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -22,10 +22,16 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import deprecate from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -38,19 +44,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ + if alpha_transform_type == "cosine": - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) @@ -181,6 +198,14 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): self.disable_corrector = disable_corrector self.solver_p = solver_p self.last_sample = None + self._step_index = None + + @property + def step_index(self): + """ + The index counter for current timestep. It will increae 1 after each scheduler step. + """ + return self._step_index def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ @@ -220,17 +245,16 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) if self.config.use_karras_sigmas: log_sigmas = np.log(sigmas) + sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - timesteps = np.flip(timesteps).copy().astype(np.int64) - - self.sigmas = torch.from_numpy(sigmas) - - # when num_inference_steps == num_train_timesteps, we can end up with - # duplicates in timesteps. - _, unique_indices = np.unique(timesteps, return_index=True) - timesteps = timesteps[np.sort(unique_indices)] + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas).to(device=device) self.timesteps = torch.from_numpy(timesteps).to(device) self.num_inference_steps = len(timesteps) @@ -243,6 +267,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): if self.solver_p: self.solver_p.set_timesteps(self.num_inference_steps, device=device) + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: """ @@ -302,6 +329,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): t = t.reshape(sigma.shape) return t + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: """Constructs the noise schedule of Karras et al. (2022).""" @@ -317,7 +351,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): return sigmas def convert_model_output( - self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + self, + model_output: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, ) -> torch.FloatTensor: r""" Convert the model output to the corresponding type the UniPC algorithm needs. @@ -334,14 +372,28 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: The converted model output. """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + if self.predict_x0: if self.config.prediction_type == "epsilon": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": x0_pred = model_output elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = alpha_t * sample - sigma_t * model_output else: raise ValueError( @@ -357,11 +409,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): if self.config.prediction_type == "epsilon": return model_output elif self.config.prediction_type == "sample": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] epsilon = (sample - alpha_t * model_output) / sigma_t return epsilon elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] epsilon = alpha_t * model_output + sigma_t * sample return epsilon else: @@ -373,9 +423,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): def multistep_uni_p_bh_update( self, model_output: torch.FloatTensor, - prev_timestep: int, - sample: torch.FloatTensor, - order: int, + *args, + sample: torch.FloatTensor = None, + order: int = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. @@ -394,10 +445,26 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: The sample tensor at the previous timestep. """ - timestep_list = self.timestep_list + prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError(" missing `order` as a required keyward argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) model_output_list = self.model_outputs - s0, t = self.timestep_list[-1], prev_timestep + s0 = self.timestep_list[-1] m0 = model_output_list[-1] x = sample @@ -405,9 +472,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): x_t = self.solver_p.step(model_output, s0, x).prev_sample return x_t - lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] - alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] - sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) h = lambda_t - lambda_s0 device = sample.device @@ -415,9 +485,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): rks = [] D1s = [] for i in range(1, order): - si = timestep_list[-(i + 1)] + si = self.step_index - i mi = model_output_list[-(i + 1)] - lambda_si = self.lambda_t[si] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) rk = (lambda_si - lambda_s0) / h rks.append(rk) D1s.append((mi - m0) / rk) @@ -481,10 +552,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): def multistep_uni_c_bh_update( self, this_model_output: torch.FloatTensor, - this_timestep: int, - last_sample: torch.FloatTensor, - this_sample: torch.FloatTensor, - order: int, + *args, + last_sample: torch.FloatTensor = None, + this_sample: torch.FloatTensor = None, + order: int = None, + **kwargs, ) -> torch.FloatTensor: """ One step for the UniC (B(h) version). @@ -505,18 +577,42 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: The corrected sample tensor at the current timestep. """ - timestep_list = self.timestep_list + this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError(" missing`last_sample` as a required keyward argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError(" missing`this_sample` as a required keyward argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError(" missing`order` as a required keyward argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + model_output_list = self.model_outputs - s0, t = timestep_list[-1], this_timestep m0 = model_output_list[-1] x = last_sample x_t = this_sample model_t = this_model_output - lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] - alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] - sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) h = lambda_t - lambda_s0 device = this_sample.device @@ -524,9 +620,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): rks = [] D1s = [] for i in range(1, order): - si = timestep_list[-(i + 1)] + si = self.step_index - (i + 1) mi = model_output_list[-(i + 1)] - lambda_si = self.lambda_t[si] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) rk = (lambda_si - lambda_s0) / h rks.append(rk) D1s.append((mi - m0) / rk) @@ -589,6 +686,25 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): x_t = x_t.to(x.dtype) return x_t + def _init_step_index(self, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + index_candidates = (self.timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + self._step_index = step_index + def step( self, model_output: torch.FloatTensor, @@ -616,37 +732,27 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): tuple is returned where the first element is the sample tensor. """ - if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero() - if len(step_index) == 0: - step_index = len(self.timesteps) - 1 - else: - step_index = step_index.item() + if self.step_index is None: + self._init_step_index(timestep) use_corrector = ( - step_index > 0 and step_index - 1 not in self.disable_corrector and self.last_sample is not None + self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None ) - model_output_convert = self.convert_model_output(model_output, timestep, sample) + model_output_convert = self.convert_model_output(model_output, sample=sample) if use_corrector: sample = self.multistep_uni_c_bh_update( this_model_output=model_output_convert, - this_timestep=timestep, last_sample=self.last_sample, this_sample=sample, order=self.this_order, ) - # now prepare to run the predictor - prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] - for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.timestep_list[i] = self.timestep_list[i + 1] @@ -655,7 +761,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): self.timestep_list[-1] = timestep if self.config.lower_order_final: - this_order = min(self.config.solver_order, len(self.timesteps) - step_index) + this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) else: this_order = self.config.solver_order @@ -665,7 +771,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): self.last_sample = sample prev_sample = self.multistep_uni_p_bh_update( model_output=model_output, # pass the original non-converted model output, in case solver-p is used - prev_timestep=prev_timestep, sample=sample, order=self.this_order, ) @@ -673,6 +778,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): if self.lower_order_nums < self.config.solver_order: self.lower_order_nums += 1 + # upon completion increase step index by one + self._step_index += 1 + if not return_dict: return (prev_sample,) @@ -693,28 +801,30 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): """ return sample - # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.IntTensor, + timesteps: torch.FloatTensor, ) -> torch.FloatTensor: - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) - sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): diff --git a/tests/schedulers/test_scheduler_deis.py b/tests/schedulers/test_scheduler_deis.py index 8b14601bc9..277aaf26e4 100644 --- a/tests/schedulers/test_scheduler_deis.py +++ b/tests/schedulers/test_scheduler_deis.py @@ -51,6 +51,7 @@ class DEISMultistepSchedulerTest(SchedulerCommonTest): output, new_output = sample, sample for t in range(time_step, time_step + scheduler.config.solver_order + 1): + t = scheduler.timesteps[t] output = scheduler.step(residual, t, output, **kwargs).prev_sample new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py index c9935780b9..6f3c818457 100644 --- a/tests/schedulers/test_scheduler_dpm_multi.py +++ b/tests/schedulers/test_scheduler_dpm_multi.py @@ -59,6 +59,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): output, new_output = sample, sample for t in range(time_step, time_step + scheduler.config.solver_order + 1): + t = new_scheduler.timesteps[t] output = scheduler.step(residual, t, output, **kwargs).prev_sample new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample @@ -91,6 +92,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): # copy over dummy past residual (must be after setting timesteps) new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order] + time_step = new_scheduler.timesteps[time_step] output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample @@ -264,10 +266,10 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): assert sample.dtype == torch.float16 - def test_unique_timesteps(self, **config): + def test_duplicated_timesteps(self, **config): for scheduler_class in self.scheduler_classes: scheduler_config = self.get_scheduler_config(**config) scheduler = scheduler_class(**scheduler_config) scheduler.set_timesteps(scheduler.config.num_train_timesteps) - assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps + assert len(scheduler.timesteps) == scheduler.num_inference_steps diff --git a/tests/schedulers/test_scheduler_dpm_multi_inverse.py b/tests/schedulers/test_scheduler_dpm_multi_inverse.py index 61a1d82e0f..014c901680 100644 --- a/tests/schedulers/test_scheduler_dpm_multi_inverse.py +++ b/tests/schedulers/test_scheduler_dpm_multi_inverse.py @@ -54,6 +54,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): output, new_output = sample, sample for t in range(time_step, time_step + scheduler.config.solver_order + 1): + t = scheduler.timesteps[t] output = scheduler.step(residual, t, output, **kwargs).prev_sample new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample @@ -222,7 +223,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_mean.item() - 1.7833) < 1e-3 + assert abs(result_mean.item() - 1.7833) < 2e-3 def test_switch(self): # make sure that iterating over schedulers with same config names gives same results diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py index 66be3d5d00..169839e776 100644 --- a/tests/schedulers/test_scheduler_dpm_single.py +++ b/tests/schedulers/test_scheduler_dpm_single.py @@ -58,6 +58,7 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest): output, new_output = sample, sample for t in range(time_step, time_step + scheduler.config.solver_order + 1): + t = scheduler.timesteps[t] output = scheduler.step(residual, t, output, **kwargs).prev_sample new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample @@ -248,3 +249,33 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest): sample = scheduler.step(residual, t, sample).prev_sample assert sample.dtype == torch.float16 + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + sample = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + scheduler.set_timesteps(num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + # copy over dummy past residuals (must be done after set_timesteps) + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10] + scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order] + + time_step_0 = scheduler.timesteps[0] + time_step_1 = scheduler.timesteps[1] + + output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample + output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py index 0495f423e5..08482fd06b 100644 --- a/tests/schedulers/test_scheduler_unipc.py +++ b/tests/schedulers/test_scheduler_unipc.py @@ -52,6 +52,7 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest): output, new_output = sample, sample for t in range(time_step, time_step + scheduler.config.solver_order + 1): + t = scheduler.timesteps[t] output = scheduler.step(residual, t, output, **kwargs).prev_sample new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample @@ -241,11 +242,3 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest): sample = scheduler.step(residual, t, sample).prev_sample assert sample.dtype == torch.float16 - - def test_unique_timesteps(self, **config): - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - - scheduler.set_timesteps(scheduler.config.num_train_timesteps) - assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps From e312b2302b5445271198ebed8f2fbcd543633f78 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 20 Sep 2023 10:30:18 +0100 Subject: [PATCH 4/4] [LoRA] support LyCORIS (#5102) * better condition. * debugging * how about now? * how about now? * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * support for lycoris. * style * add: lycoris test * fix from_pretrained call. * fix assertion values. --- src/diffusers/loaders.py | 9 ++++++++- src/diffusers/models/embeddings.py | 5 +++-- tests/lora/test_lora_layers.py | 19 +++++++++++++++++++ 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index ec77718d16..bea6e21aa6 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1878,7 +1878,7 @@ class LoraLoaderMixin: diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj") # SDXL specificity. - if "emb" in diffusers_name: + if "emb" in diffusers_name and "time" not in diffusers_name: pattern = r"\.\d+(?=\D*$)" diffusers_name = re.sub(pattern, "", diffusers_name, count=1) if ".in." in diffusers_name: @@ -1890,6 +1890,13 @@ class LoraLoaderMixin: if "skip" in diffusers_name: diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut") + # LyCORIS specificity. + if "time" in diffusers_name: + diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj") + if "conv.shortcut" in diffusers_name: + diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut") + + # General coverage. if "transformer_blocks" in diffusers_name: if "attn1" in diffusers_name or "attn2" in diffusers_name: diffusers_name = diffusers_name.replace("attn1", "attn1.processor") diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 3bdd758117..e05092de3d 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -19,6 +19,7 @@ import torch from torch import nn from .activations import get_activation +from .lora import LoRACompatibleLinear def get_timestep_embedding( @@ -166,7 +167,7 @@ class TimestepEmbedding(nn.Module): ): super().__init__() - self.linear_1 = nn.Linear(in_channels, time_embed_dim) + self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim) if cond_proj_dim is not None: self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) @@ -179,7 +180,7 @@ class TimestepEmbedding(nn.Module): time_embed_dim_out = out_dim else: time_embed_dim_out = time_embed_dim - self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) + self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out) if post_act_fn is None: self.post_act = None diff --git a/tests/lora/test_lora_layers.py b/tests/lora/test_lora_layers.py index e54caeb9f0..20d44e0c07 100644 --- a/tests/lora/test_lora_layers.py +++ b/tests/lora/test_lora_layers.py @@ -1876,6 +1876,25 @@ class LoraIntegrationTests(unittest.TestCase): self.assertTrue(np.allclose(images, expected, atol=1e-3)) + def test_lycoris(self): + generator = torch.Generator().manual_seed(0) + + pipe = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/Amixx", safety_checker=None, use_safetensors=True, variant="fp16" + ).to(torch_device) + lora_model_id = "hf-internal-testing/edgLycorisMugler-light" + lora_filename = "edgLycorisMugler-light.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.6463, 0.658, 0.599, 0.6542, 0.6512, 0.6213, 0.658, 0.6485, 0.6017]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + def test_a1111_with_model_cpu_offload(self): generator = torch.Generator().manual_seed(0)