mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Improve LCMScheduler (#5681)
* Refactor LCMScheduler.step such that prev_sample == denoised at the last timestep in the schedule. * Make timestep scaling when calculating boundary conditions configurable. * Reparameterize timestep_scaling to be a multiplicative rather than division scaling. * make style * fix dtype conversion * make style --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -182,6 +182,10 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
timestep_spacing (`str`, defaults to `"leading"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
timestep_scaling (`float`, defaults to 10.0):
|
||||
The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
|
||||
`c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
|
||||
error at the default of `10.0` is already pretty small).
|
||||
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
||||
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
||||
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
||||
@@ -208,6 +212,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
sample_max_value: float = 1.0,
|
||||
timestep_spacing: str = "leading",
|
||||
timestep_scaling: float = 10.0,
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
@@ -380,12 +385,12 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self._step_index = None
|
||||
|
||||
def get_scalings_for_boundary_condition_discrete(self, t):
|
||||
def get_scalings_for_boundary_condition_discrete(self, timestep):
|
||||
self.sigma_data = 0.5 # Default: 0.5
|
||||
scaled_timestep = timestep * self.config.timestep_scaling
|
||||
|
||||
# By dividing 0.1: This is almost a delta function at t=0.
|
||||
c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)
|
||||
c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5
|
||||
c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2)
|
||||
c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5
|
||||
return c_skip, c_out
|
||||
|
||||
def step(
|
||||
@@ -466,9 +471,12 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
denoised = c_out * predicted_original_sample + c_skip * sample
|
||||
|
||||
# 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
|
||||
# Noise is not used for one-step sampling.
|
||||
if len(self.timesteps) > 1:
|
||||
noise = randn_tensor(model_output.shape, generator=generator, device=model_output.device)
|
||||
# Noise is not used on the final timestep of the timestep schedule.
|
||||
# This also means that noise is not used for one-step sampling.
|
||||
if self.step_index != self.num_inference_steps - 1:
|
||||
noise = randn_tensor(
|
||||
model_output.shape, generator=generator, device=model_output.device, dtype=denoised.dtype
|
||||
)
|
||||
prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
|
||||
else:
|
||||
prev_sample = denoised
|
||||
|
||||
@@ -230,7 +230,7 @@ class LCMSchedulerTest(SchedulerCommonTest):
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
# TODO: get expected sum and mean
|
||||
assert abs(result_sum.item() - 18.7097) < 1e-2
|
||||
assert abs(result_sum.item() - 18.7097) < 1e-3
|
||||
assert abs(result_mean.item() - 0.0244) < 1e-3
|
||||
|
||||
def test_full_loop_multistep(self):
|
||||
@@ -240,5 +240,5 @@ class LCMSchedulerTest(SchedulerCommonTest):
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
# TODO: get expected sum and mean
|
||||
assert abs(result_sum.item() - 280.5618) < 1e-2
|
||||
assert abs(result_mean.item() - 0.3653) < 1e-3
|
||||
assert abs(result_sum.item() - 197.7616) < 1e-3
|
||||
assert abs(result_mean.item() - 0.2575) < 1e-3
|
||||
|
||||
Reference in New Issue
Block a user