diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py
index 8e2627b6f4..adcc092a81 100644
--- a/src/diffusers/schedulers/scheduling_lcm.py
+++ b/src/diffusers/schedulers/scheduling_lcm.py
@@ -182,6 +182,10 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
         timestep_spacing (`str`, defaults to `"leading"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+        timestep_scaling (`float`, defaults to 10.0):
+            The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
+            `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
+            error at the default of `10.0` is already pretty small).
         rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
@@ -208,6 +212,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
         dynamic_thresholding_ratio: float = 0.995,
         sample_max_value: float = 1.0,
         timestep_spacing: str = "leading",
+        timestep_scaling: float = 10.0,
         rescale_betas_zero_snr: bool = False,
     ):
         if trained_betas is not None:
@@ -380,12 +385,12 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
 
         self._step_index = None
 
-    def get_scalings_for_boundary_condition_discrete(self, t):
+    def get_scalings_for_boundary_condition_discrete(self, timestep):
         self.sigma_data = 0.5  # Default: 0.5
+        scaled_timestep = timestep * self.config.timestep_scaling
 
-        # By dividing 0.1: This is almost a delta function at t=0.
-        c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)
-        c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5
+        c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2)
+        c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5
         return c_skip, c_out
 
     def step(
@@ -466,9 +471,12 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
         denoised = c_out * predicted_original_sample + c_skip * sample
 
         # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
-        # Noise is not used for one-step sampling.
-        if len(self.timesteps) > 1:
-            noise = randn_tensor(model_output.shape, generator=generator, device=model_output.device)
+        # Noise is not used on the final timestep of the timestep schedule.
+        # This also means that noise is not used for one-step sampling.
+        if self.step_index != self.num_inference_steps - 1:
+            noise = randn_tensor(
+                model_output.shape, generator=generator, device=model_output.device, dtype=denoised.dtype
+            )
             prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
         else:
             prev_sample = denoised
diff --git a/tests/schedulers/test_scheduler_lcm.py b/tests/schedulers/test_scheduler_lcm.py
index 48b68fa47d..f7d511ff05 100644
--- a/tests/schedulers/test_scheduler_lcm.py
+++ b/tests/schedulers/test_scheduler_lcm.py
@@ -230,7 +230,7 @@ class LCMSchedulerTest(SchedulerCommonTest):
         result_mean = torch.mean(torch.abs(sample))
 
         # TODO: get expected sum and mean
-        assert abs(result_sum.item() - 18.7097) < 1e-2
+        assert abs(result_sum.item() - 18.7097) < 1e-3
         assert abs(result_mean.item() - 0.0244) < 1e-3
 
     def test_full_loop_multistep(self):
@@ -240,5 +240,5 @@
         result_mean = torch.mean(torch.abs(sample))
 
         # TODO: get expected sum and mean
-        assert abs(result_sum.item() - 280.5618) < 1e-2
-        assert abs(result_mean.item() - 0.3653) < 1e-3
+        assert abs(result_sum.item() - 197.7616) < 1e-3
+        assert abs(result_mean.item() - 0.2575) < 1e-3
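For reference, the new boundary-condition computation can be sketched standalone. Note that the old `t / 0.1` is algebraically identical to `t * 10.0`, so the default `timestep_scaling=10.0` preserves the previous scalings exactly; the patch just makes the factor configurable. A minimal sketch of the math (the helper name `get_scalings` is illustrative, not part of the patch):

```python
import torch

def get_scalings(timestep: torch.Tensor, timestep_scaling: float = 10.0, sigma_data: float = 0.5):
    # Mirrors get_scalings_for_boundary_condition_discrete with the new
    # configurable timestep_scaling factor.
    scaled_timestep = timestep * timestep_scaling
    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
    return c_skip, c_out

# Boundary condition: at t = 0 the model output is skipped entirely,
# so the consistency function is the identity on clean data.
c_skip, c_out = get_scalings(torch.tensor(0.0))
print(c_skip.item(), c_out.item())  # -> 1.0, 0.0

# At large t, c_skip ~ 0 and c_out ~ 1, so `denoised` is dominated by
# the model's predicted original sample. A larger timestep_scaling
# reaches this regime faster, shrinking the approximation error.
c_skip, c_out = get_scalings(torch.tensor(999.0))
print(c_skip.item(), c_out.item())
```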
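Separately, the `step()` change injects noise on every step except the last, so multi-step sampling now ends on the noise-free `denoised` prediction; that is why the expected values in `test_full_loop_multistep` change while the one-step test only tightens its tolerance. A hedged usage sketch of the new config knob (assumes a diffusers build that includes this patch):

```python
from diffusers import LCMScheduler

# timestep_scaling is now a registered config value; 10.0 reproduces the
# previous hard-coded `t / 0.1` behavior.
scheduler = LCMScheduler(timestep_scaling=10.0)
scheduler.set_timesteps(num_inference_steps=4)

# The boundary scalings now come from the config rather than a constant.
c_skip, c_out = scheduler.get_scalings_for_boundary_condition_discrete(scheduler.timesteps[0])
print(scheduler.timesteps, c_skip, c_out)
```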