diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 256eb8fee8..93f2f3a13a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -46,6 +46,7 @@ if is_torch_available(): DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, + HeunDiscreteScheduler, IPNDMScheduler, KarrasVeScheduler, PNDMScheduler, diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 5dab802ba8..01bcc6a338 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -129,8 +129,8 @@ def is_safetensors_compatible(info) -> bool: sf_filename = os.path.join(prefix, "model.safetensors") else: sf_filename = pt_filename[: -len(".bin")] + ".safetensors" - if sf_filename not in filenames: - logger.warning("{sf_filename} not found") + if is_safetensors_compatible and sf_filename not in filenames: + logger.warning(f"{sf_filename} not found") is_safetensors_compatible = False return is_safetensors_compatible @@ -767,7 +767,7 @@ class DiffusionPipeline(ConfigMixin): return pil_images - def progress_bar(self, iterable): + def progress_bar(self, iterable=None, total=None): if not hasattr(self, "_progress_bar_config"): self._progress_bar_config = {} elif not isinstance(self._progress_bar_config, dict): @@ -775,7 +775,12 @@ class DiffusionPipeline(ConfigMixin): f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." ) - return tqdm(iterable, **self._progress_bar_config) + if iterable is not None: + return tqdm(iterable, **self._progress_bar_config) + elif total is not None: + return tqdm(total=total, **self._progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") def set_progress_bar_config(self, **kwargs): self._progress_bar_config = kwargs diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index dad4eb139a..4a80f2e689 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -541,25 +541,29 @@ class AltDiffusionPipeline(DiffusionPipeline): extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Denoising loop - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + # call the callback, if provided + if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) # 8. Post-processing image = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 45df93fab0..16dbd626cd 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -433,7 +433,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): t_start = max(num_inference_steps - init_timestep + offset, 0) timesteps = self.scheduler.timesteps[t_start:] - return timesteps + return timesteps, num_inference_steps - t_start def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): init_image = init_image.to(device=device, dtype=dtype) @@ -562,7 +562,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.get_timesteps(num_inference_steps, strength, device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables @@ -574,25 +574,29 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 8. Denoising loop - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + # call the callback, if provided + if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) # 9. Post-processing image = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index 424f53d3f8..9ebbc249f6 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -475,7 +475,7 @@ class CycleDiffusionPipeline(DiffusionPipeline): t_start = max(num_inference_steps - init_timestep + offset, 0) timesteps = self.scheduler.timesteps[t_start:] - return timesteps + return timesteps, num_inference_steps - t_start def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): init_image = init_image.to(device=device, dtype=dtype) @@ -607,7 +607,7 @@ class CycleDiffusionPipeline(DiffusionPipeline): # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.get_timesteps(num_inference_steps, strength, device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables @@ -621,66 +621,70 @@ class CycleDiffusionPipeline(DiffusionPipeline): generator = extra_step_kwargs.pop("generator", None) # 8. Denoising loop - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) - source_latent_model_input = torch.cat([source_latents] * 2) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) + source_latent_model_input = torch.cat([source_latents] * 2) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t) - # predict the noise residual - concat_latent_model_input = torch.stack( - [ - source_latent_model_input[0], - latent_model_input[0], - source_latent_model_input[1], - latent_model_input[1], - ], - dim=0, - ) - concat_text_embeddings = torch.stack( - [ - source_text_embeddings[0], - text_embeddings[0], - source_text_embeddings[1], - text_embeddings[1], - ], - dim=0, - ) - concat_noise_pred = self.unet( - concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings - ).sample + # predict the noise residual + concat_latent_model_input = torch.stack( + [ + source_latent_model_input[0], + latent_model_input[0], + source_latent_model_input[1], + latent_model_input[1], + ], + dim=0, + ) + concat_text_embeddings = torch.stack( + [ + source_text_embeddings[0], + text_embeddings[0], + source_text_embeddings[1], + text_embeddings[1], + ], + dim=0, + ) + concat_noise_pred = self.unet( + concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings + ).sample - # perform guidance - ( - source_noise_pred_uncond, - noise_pred_uncond, - source_noise_pred_text, - noise_pred_text, - ) = concat_noise_pred.chunk(4, dim=0) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - source_noise_pred = source_noise_pred_uncond + source_guidance_scale * ( - source_noise_pred_text - source_noise_pred_uncond - ) + # perform guidance + ( + source_noise_pred_uncond, + noise_pred_uncond, + source_noise_pred_text, + noise_pred_text, + ) = concat_noise_pred.chunk(4, dim=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + source_noise_pred = source_noise_pred_uncond + source_guidance_scale * ( + source_noise_pred_text - source_noise_pred_uncond + ) - # Sample source_latents from the posterior distribution. - prev_source_latents = posterior_sample( - self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs - ) - # Compute noise. - noise = compute_noise( - self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs - ) - source_latents = prev_source_latents + # Sample source_latents from the posterior distribution. + prev_source_latents = posterior_sample( + self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs + ) + # Compute noise. + noise = compute_noise( + self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs + ) + source_latents = prev_source_latents - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step( - noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs - ).prev_sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs + ).prev_sample - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + # call the callback, if provided + if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) # 9. Post-processing image = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 2256477bb9..c6ff904d24 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -540,25 +540,29 @@ class StableDiffusionPipeline(DiffusionPipeline): extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Denoising loop - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + # call the callback, if provided + if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) # 8. Post-processing image = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py index fc30222a2e..e64a572a87 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -441,25 +441,29 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Denoising loop - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + # call the callback, if provided + if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) # 8. Post-processing image = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 4d645cc1f3..a25acc0bd1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -442,7 +442,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): t_start = max(num_inference_steps - init_timestep + offset, 0) timesteps = self.scheduler.timesteps[t_start:] - return timesteps + return timesteps, num_inference_steps - t_start def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): init_image = init_image.to(device=device, dtype=dtype) @@ -571,7 +571,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.get_timesteps(num_inference_steps, strength, device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables @@ -583,25 +583,29 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 8. Denoising loop - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + # call the callback, if provided + if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) # 9. Post-processing image = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 67e67ebdf1..6cb2766bc2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -655,7 +655,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps_tensor = self.scheduler.timesteps + timesteps = self.scheduler.timesteps # 6. Prepare latent variables num_channels_latents = self.vae.config.latent_channels @@ -699,28 +699,32 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 10. Denoising loop - for i, t in enumerate(self.progress_bar(timesteps_tensor)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - # concat latents, mask, masked_image_latents in the channel dimension - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + # call the callback, if provided + if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) # 11. Post-processing image = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index b7356dc6db..2440b6d5ad 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -457,7 +457,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): t_start = max(num_inference_steps - init_timestep + offset, 0) timesteps = self.scheduler.timesteps[t_start:] - return timesteps + return timesteps, num_inference_steps - t_start def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator): init_image = init_image.to(device=self.device, dtype=dtype) @@ -577,7 +577,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.get_timesteps(num_inference_steps, strength, device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables @@ -594,29 +594,33 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 9. Denoising loop - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) - latents = (init_latents_proper * mask) + (latents * (1 - mask)) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + # call the callback, if provided + if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) # 10. Post-processing image = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 7ccb43d46c..c9c238ce9a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -469,7 +469,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline): # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps_tensor = self.scheduler.timesteps + timesteps = self.scheduler.timesteps # 5. Add noise to image noise_level = torch.tensor([noise_level], dtype=torch.long, device=device) @@ -511,30 +511,34 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline): extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 9. Denoising loop - for i, t in enumerate(self.progress_bar(timesteps_tensor)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - # concat latents, mask, masked_image_latents in the channel dimension - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - latent_model_input = torch.cat([latent_model_input, image], dim=1) + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, image], dim=1) - # predict the noise residual - noise_pred = self.unet( - latent_model_input, t, encoder_hidden_states=text_embeddings, class_labels=noise_level - ).sample + # predict the noise residual + noise_pred = self.unet( + latent_model_input, t, encoder_hidden_states=text_embeddings, class_labels=noise_level + ).sample - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + # call the callback, if provided + if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) # 10. Post-processing # make sure the VAE is in float32 mode, as it overflows in float16 diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py index 776c37e4d8..7f08e40103 100644 --- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -668,63 +668,71 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): safety_momentum = None - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * (3 if enable_safety_guidance else 2)) if do_classifier_free_guidance else latents - ) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * (3 if enable_safety_guidance else 2)) + if do_classifier_free_guidance + else latents + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - # perform guidance - if do_classifier_free_guidance: - noise_pred_out = noise_pred.chunk((3 if enable_safety_guidance else 2)) - noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] + # perform guidance + if do_classifier_free_guidance: + noise_pred_out = noise_pred.chunk((3 if enable_safety_guidance else 2)) + noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] - # default classifier free guidance - noise_guidance = noise_pred_text - noise_pred_uncond + # default classifier free guidance + noise_guidance = noise_pred_text - noise_pred_uncond - # Perform SLD guidance - if enable_safety_guidance: - if safety_momentum is None: - safety_momentum = torch.zeros_like(noise_guidance) - noise_pred_safety_concept = noise_pred_out[2] + # Perform SLD guidance + if enable_safety_guidance: + if safety_momentum is None: + safety_momentum = torch.zeros_like(noise_guidance) + noise_pred_safety_concept = noise_pred_out[2] - # Equation 6 - scale = torch.clamp( - torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0 - ) + # Equation 6 + scale = torch.clamp( + torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0 + ) - # Equation 6 - safety_concept_scale = torch.where( - (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale - ) + # Equation 6 + safety_concept_scale = torch.where( + (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, + torch.zeros_like(scale), + scale, + ) - # Equation 4 - noise_guidance_safety = torch.mul( - (noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale - ) + # Equation 4 + noise_guidance_safety = torch.mul( + (noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale + ) - # Equation 7 - noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum + # Equation 7 + noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum - # Equation 8 - safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety + # Equation 8 + safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety - if i >= sld_warmup_steps: # Warmup - # Equation 3 - noise_guidance = noise_guidance - noise_guidance_safety + if i >= sld_warmup_steps: # Warmup + # Equation 3 + noise_guidance = noise_guidance - noise_guidance_safety - noise_pred = noise_pred_uncond + guidance_scale * noise_guidance + noise_pred = noise_pred_uncond + guidance_scale * noise_guidance - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + # call the callback, if provided + if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) # 8. Post-processing image = self.decode_latents(latents) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 6217bfcd69..d708963839 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -22,6 +22,7 @@ if is_torch_available(): from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from .scheduling_euler_discrete import EulerDiscreteScheduler + from .scheduling_heun import HeunDiscreteScheduler from .scheduling_ipndm import IPNDMScheduler from .scheduling_karras_ve import KarrasVeScheduler from .scheduling_pndm import PNDMScheduler diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 7d9cef3152..a2e571f998 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -114,6 +114,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() _deprecated_kwargs = ["predict_epsilon"] + order = 1 @register_to_config def __init__( diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 0112692a93..d1dfa1a44b 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -106,6 +106,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() _deprecated_kwargs = ["predict_epsilon"] + order = 1 @register_to_config def __init__( diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 76dc7acc1b..6258933dfe 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -118,6 +118,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() _deprecated_kwargs = ["predict_epsilon"] + order = 1 @register_to_config def __init__( diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 8117f30560..301ad2cebe 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -68,6 +68,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): """ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + order = 1 @register_to_config def __init__( diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 4b7b2909e7..10b0138abd 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -69,6 +69,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): """ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + order = 1 @register_to_config def __init__( diff --git a/src/diffusers/schedulers/scheduling_heun.py b/src/diffusers/schedulers/scheduling_heun.py new file mode 100644 index 0000000000..e6e5335e0d --- /dev/null +++ b/src/diffusers/schedulers/scheduling_heun.py @@ -0,0 +1,247 @@ +# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Args: + Implements Algorithm 2 (Heun steps) from Karras et al. (2022). for discrete beta schedules. Based on the original + k-diffusion implementation by Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L90 + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the + starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. + """ + + _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + order = 2 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, # sensible defaults + beta_end: float = 0.012, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + ): + if trained_betas is not None: + self.betas = torch.from_numpy(trained_betas) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # set all values + self.set_timesteps(num_train_timesteps, None, num_train_timesteps) + + def index_for_timestep(self, timestep): + indices = (self.timesteps == timestep).nonzero() + if self.state_in_first_order: + pos = 0 if indices.shape[0] < 2 else 1 + else: + pos = 0 + return indices[pos].item() + + def scale_model_input( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + ) -> torch.FloatTensor: + """ + Args: + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep + Returns: + `torch.FloatTensor`: scaled input sample + """ + step_index = self.index_for_timestep(timestep) + + sigma = self.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + return sample + + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device] = None, + num_train_timesteps: Optional[int] = None, + ): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps + + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + sigmas = torch.from_numpy(sigmas).to(device=device) + self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = self.sigmas.max() + + timesteps = torch.from_numpy(timesteps) + timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2), timesteps[-1:]]) + + if str(device).startswith("mps"): + # mps does not support float64 + self.timesteps = timesteps.to(device, dtype=torch.float32) + else: + self.timesteps = timesteps.to(device=device) + + # empty dt and derivative + self.prev_derivative = None + self.dt = None + + @property + def state_in_first_order(self): + return self.dt is None + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: Union[float, torch.FloatTensor], + sample: Union[torch.FloatTensor, np.ndarray], + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Args: + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep + (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + step_index = self.index_for_timestep(timestep) + + if self.state_in_first_order: + sigma = self.sigmas[step_index] + sigma_next = self.sigmas[step_index + 1] + else: + # 2nd order / Heun's method + sigma = self.sigmas[step_index - 1] + sigma_next = self.sigmas[step_index] + + # currently only gamma=0 is supported. This usually works best anyways. + # We can support gamma in the future but then need to scale the timestep before + # passing it to the model which requires a change in API + gamma = 0 + sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + pred_original_sample = sample - sigma_hat * model_output + + if self.state_in_first_order: + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma_hat + # 3. 1st order derivative + dt = sigma_next - sigma_hat + + # store for 2nd order step + self.prev_derivative = derivative + self.dt = dt + self.sample = sample + else: + # 2. 2nd order / Heun's method + derivative = (sample - pred_original_sample) / sigma_hat + derivative = (self.prev_derivative + derivative) / 2 + + # 3. Retrieve 1st order derivative + dt = self.dt + sample = self.sample + + # free dt and derivative + # Note, this puts the scheduler in "first order mode" + self.prev_derivative = None + self.dt = None + self.sample = None + + prev_sample = sample + derivative * dt + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + self.timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [self.index_for_timestep(t) for t in timesteps] + + sigma = self.sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_ipndm.py b/src/diffusers/schedulers/scheduling_ipndm.py index e5495713a8..1bcebe65a3 100644 --- a/src/diffusers/schedulers/scheduling_ipndm.py +++ b/src/diffusers/schedulers/scheduling_ipndm.py @@ -37,6 +37,8 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin): num_train_timesteps (`int`): number of diffusion steps used to train the model. """ + order = 1 + @register_to_config def __init__(self, num_train_timesteps: int = 1000): # set `betas`, `alphas`, `timesteps` diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index b2eb332aed..41a73b3ac3 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -77,6 +77,8 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): """ + order = 2 + @register_to_config def __init__( self, diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index cc9e8d7256..68deae8943 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -68,6 +68,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): """ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + order = 1 @register_to_config def __init__( diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 8bf0a59582..e2a076925c 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -90,6 +90,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): """ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + order = 1 @register_to_config def __init__( diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py index 55625c1bfa..0b80181f43 100644 --- a/src/diffusers/schedulers/scheduling_repaint.py +++ b/src/diffusers/schedulers/scheduling_repaint.py @@ -102,6 +102,8 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin): """ + order = 1 + @register_to_config def __init__( self, diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 1d436ab0cb..89d3d4a585 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -66,6 +66,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): correct_steps (`int`): number of correction steps performed on a produced sample. """ + order = 1 + @register_to_config def __init__( self, diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index 537d6f7e2a..5e4fe40229 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -38,6 +38,8 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): """ + order = 1 + @register_to_config def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3): self.sigmas = None diff --git a/src/diffusers/schedulers/scheduling_vq_diffusion.py b/src/diffusers/schedulers/scheduling_vq_diffusion.py index 91c46e6554..89ba722a18 100644 --- a/src/diffusers/schedulers/scheduling_vq_diffusion.py +++ b/src/diffusers/schedulers/scheduling_vq_diffusion.py @@ -138,6 +138,8 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin): The ending cumulative gamma value. """ + order = 1 + @register_to_config def __init__( self, diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 3dba3a2bc2..1c2e2c9abb 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -83,6 +83,7 @@ _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [ "PNDMScheduler", "LMSDiscreteScheduler", "EulerDiscreteScheduler", + "HeunDiscreteScheduler", "EulerAncestralDiscreteScheduler", "DPMSolverMultistepScheduler", ] diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index af2e0c7c61..9846927cb1 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -362,6 +362,21 @@ class EulerDiscreteScheduler(metaclass=DummyObject): requires_backends(cls, ["torch"]) +class HeunDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class IPNDMScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index e99f899669..229e18c69f 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -928,7 +928,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase): prompt = "astronaut riding a horse" generator = torch.Generator(device=torch_device).manual_seed(0) - output = pipe(prompt=prompt, strength=0.75, guidance_scale=7.5, generator=generator, output_type="np") + output = pipe(prompt=prompt, guidance_scale=7.5, generator=generator, output_type="np") image = output.images[0] assert image.shape == (512, 512, 3) @@ -980,7 +980,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase): callback_steps=1, ) assert test_callback_fn.has_been_called - assert number_of_steps == 51 + assert number_of_steps == 50 def test_stable_diffusion_low_cpu_mem_usage(self): pipeline_id = "CompVis/stable-diffusion-v1-4" diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py index 9b350d42e1..0e5ebe0ec7 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py @@ -351,7 +351,7 @@ class StableDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase): assert latents.shape == (1, 4, 64, 64) latents_slice = latents[0, -3:, -3:, -1] expected_slice = np.array([1.83, 1.293, -0.09705, 1.256, -2.293, 1.091, -0.0809, -0.65, -2.953]) - assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-3 elif step == 37: latents = latents.detach().cpu().numpy() assert latents.shape == (1, 4, 64, 64) @@ -386,7 +386,7 @@ class StableDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase): callback_steps=1, ) assert test_callback_fn.has_been_called - assert number_of_steps == 51 + assert number_of_steps == 50 def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): torch.cuda.empty_cache() diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index d86b259eae..0aa6e79cf8 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -635,7 +635,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): callback_steps=1, ) assert test_callback_fn.has_been_called - assert number_of_steps == 38 + assert number_of_steps == 37 def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): torch.cuda.empty_cache() diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py index 4b972c7b7d..b719566b5e 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py @@ -484,4 +484,4 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase): callback_steps=1, ) assert test_callback_fn.has_been_called - assert number_of_steps == 38 + assert number_of_steps == 37 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py index fad3b89f05..cc77abca7c 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py @@ -692,7 +692,7 @@ class StableDiffusion2PipelineIntegrationTests(unittest.TestCase): callback_steps=1, ) assert test_callback_fn.has_been_called - assert number_of_steps == 21 + assert number_of_steps == 20 def test_stable_diffusion_low_cpu_mem_usage(self): pipeline_id = "stabilityai/stable-diffusion-2-base" diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py index cfc450db4a..9b681f57a5 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py @@ -306,7 +306,7 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase): image_slice = image[0, 253:256, 253:256, -1] assert image.shape == (1, 768, 768, 3) - expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + expected_slice = np.array([0.2049, 0.2115, 0.2323, 0.2416, 0.256, 0.2484, 0.2517, 0.2358, 0.236]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 def test_stable_diffusion_attention_slicing_v_pred(self): diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 033f363ff4..6ae11e122d 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -654,7 +654,10 @@ class PipelineSlowTests(unittest.TestCase): force_download=True, ) - assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n" + assert ( + cap_logger.out + == "Keyword arguments {'not_used': True} are not expected by DDPMPipeline and will be ignored.\n" + ) def test_from_pretrained_save_pretrained(self): # 1. Load models diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 6a76581632..f90246b337 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -30,6 +30,7 @@ from diffusers import ( DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, + HeunDiscreteScheduler, IPNDMScheduler, LMSDiscreteScheduler, PNDMScheduler, @@ -1876,3 +1877,95 @@ class VQDiffusionSchedulerTest(SchedulerCommonTest): def test_add_noise_device(self): pass + + +class HeunDiscreteSchedulerTest(SchedulerCommonTest): + scheduler_classes = (HeunDiscreteScheduler,) + num_inference_steps = 10 + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1100, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + "trained_betas": None, + } + + config.update(**kwargs) + return config + + def test_timesteps(self): + for timesteps in [10, 50, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_betas(self): + for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "scaled_linear"]: + self.check_over_configs(beta_schedule=schedule) + + def test_full_loop_no_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for i, t in enumerate(scheduler.timesteps): + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + if torch_device in ["cpu", "mps"]: + assert abs(result_sum.item() - 0.1233) < 1e-2 + assert abs(result_mean.item() - 0.0002) < 1e-3 + else: + # CUDA + assert abs(result_sum.item() - 0.1233) < 1e-2 + assert abs(result_mean.item() - 0.0002) < 1e-3 + + def test_full_loop_device(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps, device=torch_device) + + model = self.dummy_model() + sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma + + for t in scheduler.timesteps: + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + if str(torch_device).startswith("cpu"): + # The following sum varies between 148 and 156 on mps. Why? + assert abs(result_sum.item() - 0.1233) < 1e-2 + assert abs(result_mean.item() - 0.0002) < 1e-3 + elif str(torch_device).startswith("mps"): + # Larger tolerance on mps + assert abs(result_mean.item() - 0.0002) < 1e-2 + else: + # CUDA + assert abs(result_sum.item() - 0.1233) < 1e-2 + assert abs(result_mean.item() - 0.0002) < 1e-3