
Improve LCMScheduler (#5681)

* Refactor LCMScheduler.step such that prev_sample == denoised at the last timestep in the schedule.

* Make the timestep scaling used when calculating the boundary conditions configurable.

* Reparameterize timestep_scaling to be a multiplicative rather than a division-based scaling (equivalent at the default value; see the check below).

* make style

* fix dtype conversion

* make style

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Author: dg845
Date: 2023-11-07 09:48:18 -08:00
Committed by: GitHub
Parent: 1dc231d14a
Commit: aab6de22c3

2 changed files with 18 additions and 10 deletions

src/diffusers/schedulers/scheduling_lcm.py

@@ -182,6 +182,10 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
         timestep_spacing (`str`, defaults to `"leading"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+        timestep_scaling (`float`, defaults to 10.0):
+            The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
+            `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
+            error at the default of `10.0` is already pretty small).
         rescale_betas_zero_snr (`bool`, defaults to `False`):
             Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
             dark samples instead of limiting it to samples with medium brightness. Loosely related to
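
The new entry is wired through the scheduler config (next hunk), so it can be set at construction time. A minimal usage sketch, assuming a diffusers version that includes this commit:

    from diffusers import LCMScheduler

    # timestep_scaling defaults to 10.0; larger values push c_skip toward 0 and
    # c_out toward 1, shrinking the boundary-condition approximation error.
    scheduler = LCMScheduler(timestep_scaling=10.0)
    print(scheduler.config.timestep_scaling)  # 10.0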
@@ -208,6 +212,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
         dynamic_thresholding_ratio: float = 0.995,
         sample_max_value: float = 1.0,
         timestep_spacing: str = "leading",
+        timestep_scaling: float = 10.0,
         rescale_betas_zero_snr: bool = False,
     ):
         if trained_betas is not None:
@@ -380,12 +385,12 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
         self._step_index = None
 
-    def get_scalings_for_boundary_condition_discrete(self, t):
+    def get_scalings_for_boundary_condition_discrete(self, timestep):
         self.sigma_data = 0.5  # Default: 0.5
+        scaled_timestep = timestep * self.config.timestep_scaling
 
-        # By dividing 0.1: This is almost a delta function at t=0.
-        c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)
-        c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5
+        c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2)
+        c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5
         return c_skip, c_out
 
     def step(
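
Since dividing by 0.1 is the same as multiplying by 10, the default timestep_scaling=10.0 reproduces the old behavior exactly. A quick standalone check of the equivalence (plain Python, mirroring the formulas above):

    import math

    sigma_data = 0.5
    timestep = 400.0

    # Old form: divide the timestep by 0.1.
    c_skip_old = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
    c_out_old = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5

    # New form: multiply by the configurable factor (default 10.0).
    scaled_timestep = timestep * 10.0
    c_skip_new = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
    c_out_new = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5

    assert math.isclose(c_skip_old, c_skip_new, rel_tol=1e-9)
    assert math.isclose(c_out_old, c_out_new, rel_tol=1e-9)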
@@ -466,9 +471,12 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
         denoised = c_out * predicted_original_sample + c_skip * sample
 
         # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
-        # Noise is not used for one-step sampling.
-        if len(self.timesteps) > 1:
-            noise = randn_tensor(model_output.shape, generator=generator, device=model_output.device)
+        # Noise is not used on the final timestep of the timestep schedule.
+        # This also means that noise is not used for one-step sampling.
+        if self.step_index != self.num_inference_steps - 1:
+            noise = randn_tensor(
+                model_output.shape, generator=generator, device=model_output.device, dtype=denoised.dtype
+            )
             prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
         else:
             prev_sample = denoised
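
Taken together, the two step changes mean the scheduler returns the consistency model's denoised prediction unchanged on the final step, and noise injected on earlier steps now matches the sample dtype. An illustrative end-to-end sketch (random tensors stand in for a real UNet; only the scheduler calls are real API):

    import torch
    from diffusers import LCMScheduler

    scheduler = LCMScheduler()
    scheduler.set_timesteps(num_inference_steps=4)

    sample = torch.randn(1, 4, 64, 64)
    generator = torch.Generator().manual_seed(0)

    for t in scheduler.timesteps:
        # Random stand-in for the UNet's epsilon prediction (illustration only).
        model_output = torch.randn_like(sample)
        sample, denoised = scheduler.step(model_output, t, sample, generator=generator, return_dict=False)

    # On the last timestep no noise is injected, so prev_sample == denoised.
    torch.testing.assert_close(sample, denoised)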

tests/schedulers/test_scheduler_lcm.py

@@ -230,7 +230,7 @@ class LCMSchedulerTest(SchedulerCommonTest):
         result_mean = torch.mean(torch.abs(sample))
 
         # TODO: get expected sum and mean
-        assert abs(result_sum.item() - 18.7097) < 1e-2
+        assert abs(result_sum.item() - 18.7097) < 1e-3
         assert abs(result_mean.item() - 0.0244) < 1e-3
 
     def test_full_loop_multistep(self):
@@ -240,5 +240,5 @@ class LCMSchedulerTest(SchedulerCommonTest):
         result_mean = torch.mean(torch.abs(sample))
 
         # TODO: get expected sum and mean
-        assert abs(result_sum.item() - 280.5618) < 1e-2
-        assert abs(result_mean.item() - 0.3653) < 1e-3
+        assert abs(result_sum.item() - 197.7616) < 1e-3
+        assert abs(result_mean.item() - 0.2575) < 1e-3