Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Add 2nd order heun scheduler (#1336)
* Add heun
* Finish first version of heun
* remove bogus
* finish
* finish
* improve
* up
* up
* fix more
* change progress bar
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
* finish
* up
* up
* up
Committed via GitHub
parent 25f11424f6
commit 4c54519e1a
@@ -46,6 +46,7 @@ if is_torch_available():
         DPMSolverMultistepScheduler,
         EulerAncestralDiscreteScheduler,
         EulerDiscreteScheduler,
+        HeunDiscreteScheduler,
         IPNDMScheduler,
         KarrasVeScheduler,
         PNDMScheduler,

@@ -129,8 +129,8 @@ def is_safetensors_compatible(info) -> bool:
             sf_filename = os.path.join(prefix, "model.safetensors")
         else:
             sf_filename = pt_filename[: -len(".bin")] + ".safetensors"
-        if sf_filename not in filenames:
-            logger.warning("{sf_filename} not found")
+        if is_safetensors_compatible and sf_filename not in filenames:
+            logger.warning(f"{sf_filename} not found")
             is_safetensors_compatible = False
     return is_safetensors_compatible

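Editorial note on the hunk above: the fix is twofold — the warning string gains the missing `f` prefix so `sf_filename` is actually interpolated, and the `is_safetensors_compatible` guard stops further warnings once the flag has already flipped to `False`. A toy trace of the repaired logic, with hypothetical filenames that are not taken from the diff:

    # Hypothetical repo listing: the unet is missing its .safetensors counterpart.
    filenames = {
        "unet/diffusion_pytorch_model.bin",
        "vae/diffusion_pytorch_model.bin",
        "vae/diffusion_pytorch_model.safetensors",
    }

    is_safetensors_compatible = True
    for pt_filename in sorted(f for f in filenames if f.endswith(".bin")):
        sf_filename = pt_filename[: -len(".bin")] + ".safetensors"
        if is_safetensors_compatible and sf_filename not in filenames:
            print(f"{sf_filename} not found")  # interpolates correctly, fires once
            is_safetensors_compatible = False

    print(is_safetensors_compatible)  # False -> loading falls back to the .bin weights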
@@ -767,7 +767,7 @@ class DiffusionPipeline(ConfigMixin):

         return pil_images

-    def progress_bar(self, iterable):
+    def progress_bar(self, iterable=None, total=None):
         if not hasattr(self, "_progress_bar_config"):
             self._progress_bar_config = {}
         elif not isinstance(self._progress_bar_config, dict):
@@ -775,7 +775,12 @@ class DiffusionPipeline(ConfigMixin):
                 f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
             )

-        return tqdm(iterable, **self._progress_bar_config)
+        if iterable is not None:
+            return tqdm(iterable, **self._progress_bar_config)
+        elif total is not None:
+            return tqdm(total=total, **self._progress_bar_config)
+        else:
+            raise ValueError("Either `total` or `iterable` has to be defined.")

     def set_progress_bar_config(self, **kwargs):
         self._progress_bar_config = kwargs

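Editorial note: after this change `progress_bar` supports two modes — wrap an iterable as before, or create a manually advanced bar via `total=`. The manual mode is what lets the denoising loops below tick once per visible step even when a second-order scheduler makes two UNet calls per step. A minimal sketch, assuming `pipe` is any loaded `DiffusionPipeline`:

    # Mode 1: iterable (old behaviour, still supported)
    for t in pipe.progress_bar(range(50)):
        ...  # one tick per iteration

    # Mode 2: manual total (new) -- the caller decides what counts as a step
    with pipe.progress_bar(total=50) as progress_bar:
        for i in range(100):           # e.g. two model calls per denoising step
            if (i + 1) % 2 == 0:
                progress_bar.update()  # tick once per completed step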
@@ -541,25 +541,29 @@ class AltDiffusionPipeline(DiffusionPipeline):
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

         # 7. Denoising loop
-        for i, t in enumerate(self.progress_bar(timesteps)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-            # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-            # call the callback, if provided
-            if callback is not None and i % callback_steps == 0:
-                callback(i, t, latents)
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                # call the callback, if provided
+                if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)

         # 8. Post-processing
         image = self.decode_latents(latents)

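Editorial note on the warmup arithmetic introduced above: `num_warmup_steps` counts the loop iterations that should not advance the bar, and the `(i + 1) % self.scheduler.order` test collapses a second-order scheduler's two model calls into one visible step. Illustrative arithmetic only — the concrete values are assumed, not prescribed by the diff:

    # HeunDiscreteScheduler: every step is a predictor + corrector pass.
    order = 2
    num_inference_steps = 25
    len_timesteps = 2 * num_inference_steps          # duplicated timesteps
    num_warmup_steps = len_timesteps - num_inference_steps * order
    assert num_warmup_steps == 0                     # bar ticks every 2nd iteration

    # Order-1 scheduler that emits one extra leading timestep (as the default
    # Stable Diffusion scheduler does): 50 requested steps, 51 loop iterations.
    order = 1
    num_inference_steps = 50
    len_timesteps = 51
    num_warmup_steps = len_timesteps - num_inference_steps * order
    assert num_warmup_steps == 1                     # 50 updates, not 51 -- which is
    # why the integration tests at the end of this commit change 51 -> 50, 38 -> 37, 21 -> 20.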
@@ -433,7 +433,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         t_start = max(num_inference_steps - init_timestep + offset, 0)
         timesteps = self.scheduler.timesteps[t_start:]

-        return timesteps
+        return timesteps, num_inference_steps - t_start

     def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
         init_image = init_image.to(device=device, dtype=dtype)
@@ -562,7 +562,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):

         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.get_timesteps(num_inference_steps, strength, device)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

         # 6. Prepare latent variables
@@ -574,25 +574,29 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

         # 8. Denoising loop
-        for i, t in enumerate(self.progress_bar(timesteps)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-            # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-            # call the callback, if provided
-            if callback is not None and i % callback_steps == 0:
-                callback(i, t, latents)
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                # call the callback, if provided
+                if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)

         # 9. Post-processing
         image = self.decode_latents(latents)

@@ -475,7 +475,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
         t_start = max(num_inference_steps - init_timestep + offset, 0)
         timesteps = self.scheduler.timesteps[t_start:]

-        return timesteps
+        return timesteps, num_inference_steps - t_start

     def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
         init_image = init_image.to(device=device, dtype=dtype)
@@ -607,7 +607,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):

         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.get_timesteps(num_inference_steps, strength, device)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

         # 6. Prepare latent variables
@@ -621,66 +621,70 @@ class CycleDiffusionPipeline(DiffusionPipeline):
         generator = extra_step_kwargs.pop("generator", None)

         # 8. Denoising loop
-        for i, t in enumerate(self.progress_bar(timesteps)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2)
-            source_latent_model_input = torch.cat([source_latents] * 2)
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-            source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t)
-
-            # predict the noise residual
-            concat_latent_model_input = torch.stack(
-                [
-                    source_latent_model_input[0],
-                    latent_model_input[0],
-                    source_latent_model_input[1],
-                    latent_model_input[1],
-                ],
-                dim=0,
-            )
-            concat_text_embeddings = torch.stack(
-                [
-                    source_text_embeddings[0],
-                    text_embeddings[0],
-                    source_text_embeddings[1],
-                    text_embeddings[1],
-                ],
-                dim=0,
-            )
-            concat_noise_pred = self.unet(
-                concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings
-            ).sample
-
-            # perform guidance
-            (
-                source_noise_pred_uncond,
-                noise_pred_uncond,
-                source_noise_pred_text,
-                noise_pred_text,
-            ) = concat_noise_pred.chunk(4, dim=0)
-            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-            source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
-                source_noise_pred_text - source_noise_pred_uncond
-            )
-
-            # Sample source_latents from the posterior distribution.
-            prev_source_latents = posterior_sample(
-                self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs
-            )
-            # Compute noise.
-            noise = compute_noise(
-                self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs
-            )
-            source_latents = prev_source_latents
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(
-                noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs
-            ).prev_sample
-
-            # call the callback, if provided
-            if callback is not None and i % callback_steps == 0:
-                callback(i, t, latents)
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2)
+                source_latent_model_input = torch.cat([source_latents] * 2)
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t)
+
+                # predict the noise residual
+                concat_latent_model_input = torch.stack(
+                    [
+                        source_latent_model_input[0],
+                        latent_model_input[0],
+                        source_latent_model_input[1],
+                        latent_model_input[1],
+                    ],
+                    dim=0,
+                )
+                concat_text_embeddings = torch.stack(
+                    [
+                        source_text_embeddings[0],
+                        text_embeddings[0],
+                        source_text_embeddings[1],
+                        text_embeddings[1],
+                    ],
+                    dim=0,
+                )
+                concat_noise_pred = self.unet(
+                    concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings
+                ).sample
+
+                # perform guidance
+                (
+                    source_noise_pred_uncond,
+                    noise_pred_uncond,
+                    source_noise_pred_text,
+                    noise_pred_text,
+                ) = concat_noise_pred.chunk(4, dim=0)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
+                    source_noise_pred_text - source_noise_pred_uncond
+                )
+
+                # Sample source_latents from the posterior distribution.
+                prev_source_latents = posterior_sample(
+                    self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs
+                )
+                # Compute noise.
+                noise = compute_noise(
+                    self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs
+                )
+                source_latents = prev_source_latents
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(
+                    noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs
+                ).prev_sample
+
+                # call the callback, if provided
+                if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)

         # 9. Post-processing
         image = self.decode_latents(latents)

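Editorial note: the interleaved `torch.stack([...])` above exists so one batched UNet call serves both the source and target branches of CycleDiffusion; `chunk(4, dim=0)` then hands the four predictions back in exactly the stacking order (source-uncond, uncond, source-text, text). A toy roundtrip check, assuming nothing beyond PyTorch:

    import torch

    # Stand-ins for the four single-sample latents (shape [C, H, W] each).
    src_uncond = torch.zeros(4, 8, 8)
    tgt_uncond = torch.ones(4, 8, 8)
    src_text = 2 * torch.ones(4, 8, 8)
    tgt_text = 3 * torch.ones(4, 8, 8)

    batch = torch.stack([src_uncond, tgt_uncond, src_text, tgt_text], dim=0)
    a, b, c, d = batch.chunk(4, dim=0)  # returned in stacking order, each [1, C, H, W]

    assert torch.equal(a.squeeze(0), src_uncond)
    assert torch.equal(d.squeeze(0), tgt_text)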
@@ -540,25 +540,29 @@ class StableDiffusionPipeline(DiffusionPipeline):
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

         # 7. Denoising loop
-        for i, t in enumerate(self.progress_bar(timesteps)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-            # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-            # call the callback, if provided
-            if callback is not None and i % callback_steps == 0:
-                callback(i, t, latents)
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                # call the callback, if provided
+                if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)

         # 8. Post-processing
         image = self.decode_latents(latents)

@@ -441,25 +441,29 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

         # 7. Denoising loop
-        for i, t in enumerate(self.progress_bar(timesteps)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-            # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-            # call the callback, if provided
-            if callback is not None and i % callback_steps == 0:
-                callback(i, t, latents)
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                # call the callback, if provided
+                if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)

         # 8. Post-processing
         image = self.decode_latents(latents)

@@ -442,7 +442,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         t_start = max(num_inference_steps - init_timestep + offset, 0)
         timesteps = self.scheduler.timesteps[t_start:]

-        return timesteps
+        return timesteps, num_inference_steps - t_start

     def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
         init_image = init_image.to(device=device, dtype=dtype)
@@ -571,7 +571,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):

         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.get_timesteps(num_inference_steps, strength, device)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

         # 6. Prepare latent variables
@@ -583,25 +583,29 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

         # 8. Denoising loop
-        for i, t in enumerate(self.progress_bar(timesteps)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-            # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-            # call the callback, if provided
-            if callback is not None and i % callback_steps == 0:
-                callback(i, t, latents)
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                # call the callback, if provided
+                if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)

         # 9. Post-processing
         image = self.decode_latents(latents)

@@ -655,7 +655,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):

         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps_tensor = self.scheduler.timesteps
+        timesteps = self.scheduler.timesteps

         # 6. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
@@ -699,28 +699,32 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

         # 10. Denoising loop
-        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-
-            # concat latents, mask, masked_image_latents in the channel dimension
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-            latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
-
-            # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-            # call the callback, if provided
-            if callback is not None and i % callback_steps == 0:
-                callback(i, t, latents)
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+                # concat latents, mask, masked_image_latents in the channel dimension
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+                # predict the noise residual
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                # call the callback, if provided
+                if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)

         # 11. Post-processing
         image = self.decode_latents(latents)

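Editorial note: the channel-wise concat above is what distinguishes the dedicated inpainting UNet — scaled noisy latents, the downsampled mask, and the VAE-encoded masked image are stacked along dim=1. A shape sketch with assumed sizes (batch 1, 64x64 latent grid; the 4+1+4 = 9 channel total matches the Stable Diffusion inpaint UNet's in_channels):

    import torch

    latents = torch.randn(1, 4, 64, 64)               # noisy image latents
    mask = torch.randn(1, 1, 64, 64)                  # downsampled inpaint mask
    masked_image_latents = torch.randn(1, 4, 64, 64)  # VAE-encoded masked image

    unet_input = torch.cat([latents, mask, masked_image_latents], dim=1)
    print(unet_input.shape)  # torch.Size([1, 9, 64, 64])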
@@ -457,7 +457,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
         t_start = max(num_inference_steps - init_timestep + offset, 0)
         timesteps = self.scheduler.timesteps[t_start:]

-        return timesteps
+        return timesteps, num_inference_steps - t_start

     def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator):
         init_image = init_image.to(device=self.device, dtype=dtype)
@@ -577,7 +577,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):

         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.get_timesteps(num_inference_steps, strength, device)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

         # 6. Prepare latent variables
@@ -594,29 +594,33 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

         # 9. Denoising loop
-        for i, t in enumerate(self.progress_bar(timesteps)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-            # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-            # masking
-            init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
-
-            latents = (init_latents_proper * mask) + (latents * (1 - mask))
-
-            # call the callback, if provided
-            if callback is not None and i % callback_steps == 0:
-                callback(i, t, latents)
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                # masking
+                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
+
+                latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+                # call the callback, if provided
+                if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)

         # 10. Post-processing
         image = self.decode_latents(latents)

@@ -469,7 +469,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):

         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps_tensor = self.scheduler.timesteps
+        timesteps = self.scheduler.timesteps

         # 5. Add noise to image
         noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
@@ -511,30 +511,34 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

         # 9. Denoising loop
-        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-
-            # concat latents, mask, masked_image_latents in the channel dimension
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-            latent_model_input = torch.cat([latent_model_input, image], dim=1)
-
-            # predict the noise residual
-            noise_pred = self.unet(
-                latent_model_input, t, encoder_hidden_states=text_embeddings, class_labels=noise_level
-            ).sample
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-            # call the callback, if provided
-            if callback is not None and i % callback_steps == 0:
-                callback(i, t, latents)
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+                # concat latents, mask, masked_image_latents in the channel dimension
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                latent_model_input = torch.cat([latent_model_input, image], dim=1)
+
+                # predict the noise residual
+                noise_pred = self.unet(
+                    latent_model_input, t, encoder_hidden_states=text_embeddings, class_labels=noise_level
+                ).sample
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                # call the callback, if provided
+                if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)

         # 10. Post-processing
         # make sure the VAE is in float32 mode, as it overflows in float16

@@ -668,63 +668,71 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):

         safety_momentum = None

-        for i, t in enumerate(self.progress_bar(timesteps)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = (
-                torch.cat([latents] * (3 if enable_safety_guidance else 2)) if do_classifier_free_guidance else latents
-            )
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-            # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_out = noise_pred.chunk((3 if enable_safety_guidance else 2))
-                noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1]
-
-                # default classifier free guidance
-                noise_guidance = noise_pred_text - noise_pred_uncond
-
-                # Perform SLD guidance
-                if enable_safety_guidance:
-                    if safety_momentum is None:
-                        safety_momentum = torch.zeros_like(noise_guidance)
-                    noise_pred_safety_concept = noise_pred_out[2]
-
-                    # Equation 6
-                    scale = torch.clamp(
-                        torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0
-                    )
-
-                    # Equation 6
-                    safety_concept_scale = torch.where(
-                        (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale
-                    )
-
-                    # Equation 4
-                    noise_guidance_safety = torch.mul(
-                        (noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale
-                    )
-
-                    # Equation 7
-                    noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum
-
-                    # Equation 8
-                    safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety
-
-                    if i >= sld_warmup_steps:  # Warmup
-                        # Equation 3
-                        noise_guidance = noise_guidance - noise_guidance_safety
-
-                noise_pred = noise_pred_uncond + guidance_scale * noise_guidance
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-            # call the callback, if provided
-            if callback is not None and i % callback_steps == 0:
-                callback(i, t, latents)
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = (
+                    torch.cat([latents] * (3 if enable_safety_guidance else 2))
+                    if do_classifier_free_guidance
+                    else latents
+                )
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_out = noise_pred.chunk((3 if enable_safety_guidance else 2))
+                    noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1]
+
+                    # default classifier free guidance
+                    noise_guidance = noise_pred_text - noise_pred_uncond
+
+                    # Perform SLD guidance
+                    if enable_safety_guidance:
+                        if safety_momentum is None:
+                            safety_momentum = torch.zeros_like(noise_guidance)
+                        noise_pred_safety_concept = noise_pred_out[2]
+
+                        # Equation 6
+                        scale = torch.clamp(
+                            torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0
+                        )
+
+                        # Equation 6
+                        safety_concept_scale = torch.where(
+                            (noise_pred_text - noise_pred_safety_concept) >= sld_threshold,
+                            torch.zeros_like(scale),
+                            scale,
+                        )
+
+                        # Equation 4
+                        noise_guidance_safety = torch.mul(
+                            (noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale
+                        )
+
+                        # Equation 7
+                        noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum
+
+                        # Equation 8
+                        safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety
+
+                        if i >= sld_warmup_steps:  # Warmup
+                            # Equation 3
+                            noise_guidance = noise_guidance - noise_guidance_safety
+
+                    noise_pred = noise_pred_uncond + guidance_scale * noise_guidance
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                # call the callback, if provided
+                if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)

         # 8. Post-processing
         image = self.decode_latents(latents)

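Editorial note: the "Equation N" comments refer to the Safe Latent Diffusion paper. Read directly off the code above — with \hat\epsilon_u, \hat\epsilon_p, \hat\epsilon_S the unconditional, prompt, and safety-concept noise predictions, \nu the momentum buffer, and s_g, \lambda, s_m, \beta standing for the sld_* parameters — the per-step update is:

    \begin{aligned}
    s &= \min\bigl(1,\ \lvert\hat\epsilon_p - \hat\epsilon_S\rvert\, s_g\bigr) && \text{(Eq. 6)}\\
    \phi &= \begin{cases} 0 & \text{where } \hat\epsilon_p - \hat\epsilon_S \ge \lambda\\ s & \text{otherwise}\end{cases} && \text{(Eq. 6)}\\
    \gamma &= (\hat\epsilon_S - \hat\epsilon_u)\odot\phi && \text{(Eq. 4)}\\
    \gamma &\leftarrow \gamma + s_m\,\nu && \text{(Eq. 7)}\\
    \nu &\leftarrow \beta\,\nu + (1-\beta)\,\gamma && \text{(Eq. 8)}\\
    \hat\epsilon &= \hat\epsilon_u + g\bigl((\hat\epsilon_p - \hat\epsilon_u) - \gamma\bigr) && \text{(Eq. 3, once } i \ge \text{sld\_warmup\_steps)}
    \end{aligned}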
@@ -22,6 +22,7 @@ if is_torch_available():
     from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
     from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
     from .scheduling_euler_discrete import EulerDiscreteScheduler
+    from .scheduling_heun import HeunDiscreteScheduler
     from .scheduling_ipndm import IPNDMScheduler
     from .scheduling_karras_ve import KarrasVeScheduler
     from .scheduling_pndm import PNDMScheduler

@@ -114,6 +114,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):

     _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
     _deprecated_kwargs = ["predict_epsilon"]
+    order = 1

     @register_to_config
     def __init__(

@@ -106,6 +106,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):

     _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
     _deprecated_kwargs = ["predict_epsilon"]
+    order = 1

     @register_to_config
     def __init__(

@@ -118,6 +118,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):

     _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
     _deprecated_kwargs = ["predict_epsilon"]
+    order = 1

     @register_to_config
     def __init__(

@@ -68,6 +68,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
     """

     _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    order = 1

     @register_to_config
     def __init__(

@@ -69,6 +69,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
     """

     _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    order = 1

     @register_to_config
     def __init__(

src/diffusers/schedulers/scheduling_heun.py (new file, 247 lines)
@@ -0,0 +1,247 @@
# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput


class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Implements Algorithm 2 (Heun steps) from Karras et al. (2022) for discrete beta schedules. Based on the original
    k-diffusion implementation by Katherine Crowson:
    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L90

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
    [`~ConfigMixin.from_config`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
    """

    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,  # sensible defaults
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
    ):
        if trained_betas is not None:
            self.betas = torch.from_numpy(trained_betas)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # set all values
        self.set_timesteps(num_train_timesteps, None, num_train_timesteps)

    def index_for_timestep(self, timestep):
        indices = (self.timesteps == timestep).nonzero()
        if self.state_in_first_order:
            pos = 0 if indices.shape[0] < 2 else 1
        else:
            pos = 0
        return indices[pos].item()

    def scale_model_input(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
    ) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        step_index = self.index_for_timestep(timestep)

        sigma = self.sigmas[step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        return sample

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Union[str, torch.device] = None,
        num_train_timesteps: Optional[int] = None,
    ):
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

        timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        sigmas = torch.from_numpy(sigmas).to(device=device)
        self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = self.sigmas.max()

        timesteps = torch.from_numpy(timesteps)
        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2), timesteps[-1:]])

        if str(device).startswith("mps"):
            # mps does not support float64
            self.timesteps = timesteps.to(device, dtype=torch.float32)
        else:
            self.timesteps = timesteps.to(device=device)

        # empty dt and derivative
        self.prev_derivative = None
        self.dt = None

    @property
    def state_in_first_order(self):
        return self.dt is None

    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: Union[float, torch.FloatTensor],
        sample: Union[torch.FloatTensor, np.ndarray],
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor` or `np.ndarray`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        step_index = self.index_for_timestep(timestep)

        if self.state_in_first_order:
            sigma = self.sigmas[step_index]
            sigma_next = self.sigmas[step_index + 1]
        else:
            # 2nd order / Heun's method
            sigma = self.sigmas[step_index - 1]
            sigma_next = self.sigmas[step_index]

        # currently only gamma=0 is supported. This usually works best anyways.
        # We can support gamma in the future but then need to scale the timestep before
        # passing it to the model which requires a change in API
        gamma = 0
        sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        pred_original_sample = sample - sigma_hat * model_output

        if self.state_in_first_order:
            # 2. Convert to an ODE derivative
            derivative = (sample - pred_original_sample) / sigma_hat
            # 3. 1st order derivative
            dt = sigma_next - sigma_hat

            # store for 2nd order step
            self.prev_derivative = derivative
            self.dt = dt
            self.sample = sample
        else:
            # 2. 2nd order / Heun's method
            derivative = (sample - pred_original_sample) / sigma_hat
            derivative = (self.prev_derivative + derivative) / 2

            # 3. Retrieve 1st order derivative
            dt = self.dt
            sample = self.sample

            # free dt and derivative
            # Note, this puts the scheduler in "first order mode"
            self.prev_derivative = None
            self.dt = None
            self.sample = None

        prev_sample = sample + derivative * dt

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            self.timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        step_indices = [self.index_for_timestep(t) for t in timesteps]

        sigma = self.sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps

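Editorial note on the new scheduler: because `order = 2`, `set_timesteps` registers every timestep except the endpoints twice (`repeat_interleave(2)`), and `state_in_first_order` toggles each `step()` call between the Euler predictor and the Heun corrector — `index_for_timestep` picks the right occurrence of a duplicated timestep for the current phase. A hedged usage sketch (the checkpoint id and step count are examples, not prescribed by this commit):

    import torch
    from diffusers import HeunDiscreteScheduler, StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
    pipe.scheduler = HeunDiscreteScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to("cuda")

    # 25 requested steps -> len(pipe.scheduler.timesteps) == 50 (a predictor and
    # a corrector pass per step), yet the progress bar shows 25 thanks to the
    # order-aware update logic added to the pipelines in this commit.
    image = pipe("astronaut riding a horse", num_inference_steps=25).images[0]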
@@ -37,6 +37,8 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
         num_train_timesteps (`int`): number of diffusion steps used to train the model.
     """

+    order = 1
+
     @register_to_config
     def __init__(self, num_train_timesteps: int = 1000):
         # set `betas`, `alphas`, `timesteps`

@@ -77,6 +77,8 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):

     """

+    order = 2
+
     @register_to_config
     def __init__(
         self,

@@ -68,6 +68,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
     """

     _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    order = 1

     @register_to_config
     def __init__(

@@ -90,6 +90,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
     """

     _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    order = 1

     @register_to_config
     def __init__(

@@ -102,6 +102,8 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):

     """

+    order = 1
+
     @register_to_config
     def __init__(
         self,

@@ -66,6 +66,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         correct_steps (`int`): number of correction steps performed on a produced sample.
     """

+    order = 1
+
     @register_to_config
     def __init__(
         self,

@@ -38,6 +38,8 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):

     """

+    order = 1
+
     @register_to_config
     def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3):
         self.sigmas = None

@@ -138,6 +138,8 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
            The ending cumulative gamma value.
     """

+    order = 1
+
     @register_to_config
     def __init__(
         self,

@@ -83,6 +83,7 @@ _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [
     "PNDMScheduler",
     "LMSDiscreteScheduler",
     "EulerDiscreteScheduler",
+    "HeunDiscreteScheduler",
     "EulerAncestralDiscreteScheduler",
     "DPMSolverMultistepScheduler",
 ]

@@ -362,6 +362,21 @@ class EulerDiscreteScheduler(metaclass=DummyObject):
         requires_backends(cls, ["torch"])


+class HeunDiscreteScheduler(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class IPNDMScheduler(metaclass=DummyObject):
     _backends = ["torch"]

@@ -928,7 +928,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
         prompt = "astronaut riding a horse"

         generator = torch.Generator(device=torch_device).manual_seed(0)
-        output = pipe(prompt=prompt, strength=0.75, guidance_scale=7.5, generator=generator, output_type="np")
+        output = pipe(prompt=prompt, guidance_scale=7.5, generator=generator, output_type="np")
         image = output.images[0]

         assert image.shape == (512, 512, 3)
@@ -980,7 +980,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
             callback_steps=1,
         )
         assert test_callback_fn.has_been_called
-        assert number_of_steps == 51
+        assert number_of_steps == 50

     def test_stable_diffusion_low_cpu_mem_usage(self):
         pipeline_id = "CompVis/stable-diffusion-v1-4"
@@ -351,7 +351,7 @@ class StableDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase):
                 assert latents.shape == (1, 4, 64, 64)
                 latents_slice = latents[0, -3:, -3:, -1]
                 expected_slice = np.array([1.83, 1.293, -0.09705, 1.256, -2.293, 1.091, -0.0809, -0.65, -2.953])
-                assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
+                assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-3
             elif step == 37:
                 latents = latents.detach().cpu().numpy()
                 assert latents.shape == (1, 4, 64, 64)
@@ -386,7 +386,7 @@ class StableDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase):
             callback_steps=1,
         )
         assert test_callback_fn.has_been_called
-        assert number_of_steps == 51
+        assert number_of_steps == 50

     def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
         torch.cuda.empty_cache()
@@ -635,7 +635,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
             callback_steps=1,
        )
         assert test_callback_fn.has_been_called
-        assert number_of_steps == 38
+        assert number_of_steps == 37

     def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
         torch.cuda.empty_cache()
@@ -484,4 +484,4 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
             callback_steps=1,
         )
         assert test_callback_fn.has_been_called
-        assert number_of_steps == 38
+        assert number_of_steps == 37

@@ -692,7 +692,7 @@ class StableDiffusion2PipelineIntegrationTests(unittest.TestCase):
             callback_steps=1,
         )
         assert test_callback_fn.has_been_called
-        assert number_of_steps == 21
+        assert number_of_steps == 20

     def test_stable_diffusion_low_cpu_mem_usage(self):
         pipeline_id = "stabilityai/stable-diffusion-2-base"
@@ -306,7 +306,7 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):

         image_slice = image[0, 253:256, 253:256, -1]
         assert image.shape == (1, 768, 768, 3)
-        expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
+        expected_slice = np.array([0.2049, 0.2115, 0.2323, 0.2416, 0.256, 0.2484, 0.2517, 0.2358, 0.236])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     def test_stable_diffusion_attention_slicing_v_pred(self):
@@ -654,7 +654,10 @@ class PipelineSlowTests(unittest.TestCase):
             force_download=True,
         )

-        assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n"
+        assert (
+            cap_logger.out
+            == "Keyword arguments {'not_used': True} are not expected by DDPMPipeline and will be ignored.\n"
+        )

     def test_from_pretrained_save_pretrained(self):
         # 1. Load models
@@ -30,6 +30,7 @@ from diffusers import (
     DPMSolverMultistepScheduler,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
+    HeunDiscreteScheduler,
     IPNDMScheduler,
     LMSDiscreteScheduler,
     PNDMScheduler,

@@ -1876,3 +1877,95 @@ class VQDiffusionSchedulerTest(SchedulerCommonTest):

     def test_add_noise_device(self):
         pass
+
+
+class HeunDiscreteSchedulerTest(SchedulerCommonTest):
+    scheduler_classes = (HeunDiscreteScheduler,)
+    num_inference_steps = 10
+
+    def get_scheduler_config(self, **kwargs):
+        config = {
+            "num_train_timesteps": 1100,
+            "beta_start": 0.0001,
+            "beta_end": 0.02,
+            "beta_schedule": "linear",
+            "trained_betas": None,
+        }
+
+        config.update(**kwargs)
+        return config
+
+    def test_timesteps(self):
+        for timesteps in [10, 50, 100, 1000]:
+            self.check_over_configs(num_train_timesteps=timesteps)
+
+    def test_betas(self):
+        for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]):
+            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
+
+    def test_schedules(self):
+        for schedule in ["linear", "scaled_linear"]:
+            self.check_over_configs(beta_schedule=schedule)
+
+    def test_full_loop_no_noise(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        scheduler.set_timesteps(self.num_inference_steps)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
+        sample = sample.to(torch_device)
+
+        for i, t in enumerate(scheduler.timesteps):
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        if torch_device in ["cpu", "mps"]:
+            assert abs(result_sum.item() - 0.1233) < 1e-2
+            assert abs(result_mean.item() - 0.0002) < 1e-3
+        else:
+            # CUDA
+            assert abs(result_sum.item() - 0.1233) < 1e-2
+            assert abs(result_mean.item() - 0.0002) < 1e-3
+
+    def test_full_loop_device(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma
+
+        for t in scheduler.timesteps:
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        if str(torch_device).startswith("cpu"):
+            # The following sum varies between 148 and 156 on mps. Why?
+            assert abs(result_sum.item() - 0.1233) < 1e-2
+            assert abs(result_mean.item() - 0.0002) < 1e-3
+        elif str(torch_device).startswith("mps"):
+            # Larger tolerance on mps
+            assert abs(result_mean.item() - 0.0002) < 1e-2
+        else:
+            # CUDA
+            assert abs(result_sum.item() - 0.1233) < 1e-2
+            assert abs(result_mean.item() - 0.0002) < 1e-3