From efa773afd2a99f6041043298d9f3e8bcdaa325c7 Mon Sep 17 00:00:00 2001
From: Anton Lozhkov
Date: Mon, 29 Aug 2022 17:17:05 +0200
Subject: [PATCH] Support K-LMS in img2img (#270)

* Support K-LMS in img2img

* Apply review suggestions
---
 examples/inference/image_to_image.py          | 34 +++++++++++++++----
 .../pipeline_stable_diffusion.py              |  1 +
 .../schedulers/scheduling_lms_discrete.py     |  5 +--
 3 files changed, 32 insertions(+), 8 deletions(-)

diff --git a/examples/inference/image_to_image.py b/examples/inference/image_to_image.py
index e5f34ad3df..fbcb5e338c 100644
--- a/examples/inference/image_to_image.py
+++ b/examples/inference/image_to_image.py
@@ -5,7 +5,14 @@
 import numpy as np
 import torch
 import PIL
-from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, PNDMScheduler, UNet2DConditionModel
+from diffusers import (
+    AutoencoderKL,
+    DDIMScheduler,
+    DiffusionPipeline,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+    UNet2DConditionModel,
+)
 from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -87,12 +94,17 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         # get the original timestep using init_timestep
         init_timestep = int(num_inference_steps * strength) + offset
         init_timestep = min(init_timestep, num_inference_steps)
-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+        if isinstance(self.scheduler, LMSDiscreteScheduler):
+            timesteps = torch.tensor(
+                [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
+            )
+        else:
+            timesteps = self.scheduler.timesteps[-init_timestep]
+            timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
 
         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
-        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(self.device)
 
         # get prompt text embeddings
         text_input = self.tokenizer(
@@ -133,8 +145,15 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         latents = init_latents
         t_start = max(num_inference_steps - init_timestep + offset, 0)
         for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
+            t_index = t_start + i
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                sigma = self.scheduler.sigmas[t_index]
+                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+                latent_model_input = latent_model_input.to(self.unet.dtype)
+                t = t.to(self.unet.dtype)
 
             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
@@ -145,11 +164,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs)["prev_sample"]
+            else:
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents)
+        image = self.vae.decode(latents.to(self.vae.dtype))
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index f0b353d931..fca8715118 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -138,6 +138,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
             if isinstance(self.scheduler, LMSDiscreteScheduler):
                 sigma = self.scheduler.sigmas[i]
+                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
                 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
 
             # predict the noise residual
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py
index 9ac39c79d5..e6adcaac58 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -124,8 +124,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         return {"prev_sample": prev_sample}
 
     def add_noise(self, original_samples, noise, timesteps):
-        sigmas = self.match_shape(self.sigmas, noise)
-        noisy_samples = original_samples + noise * sigmas
+        sigmas = self.match_shape(self.sigmas[timesteps], noise)
+        noisy_samples = original_samples + noise * sigmas
+
         return noisy_samples
 
     def __len__(self):
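
Usage note (appended for context; not part of the patch): K-LMS keeps its working
samples scaled by sigma, which is why the hunks above divide the UNet input by
(sigma**2 + 1) ** 0.5 and pass an index into scheduler.sigmas (t_index) to step()
instead of the raw timestep. A minimal sketch of driving the patched example with
K-LMS follows; it assumes the patched examples/inference/image_to_image.py is
importable as image_to_image and still exposes its preprocess helper, and the
checkpoint id, file names, seed, and strength/guidance values are illustrative:

    import PIL.Image
    import torch
    from diffusers import LMSDiscreteScheduler

    # assumption: run from examples/inference/ so the patched module resolves
    from image_to_image import StableDiffusionImg2ImgPipeline, preprocess

    # K-LMS configured with the beta schedule used by Stable Diffusion checkpoints
    scheduler = LMSDiscreteScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
    )

    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",  # assumed checkpoint; gated behind the model license
        scheduler=scheduler,
        use_auth_token=True,
    ).to("cuda")

    # load the init image and normalize it to [-1, 1] with the example's preprocess helper
    init_image = preprocess(PIL.Image.open("sketch.png").convert("RGB").resize((512, 512)))

    generator = torch.Generator(device="cuda").manual_seed(0)
    image = pipe(
        prompt="A fantasy landscape, trending on artstation",
        init_image=init_image,
        strength=0.75,  # fraction of the schedule to re-noise (see the init_timestep logic above)
        guidance_scale=7.5,
        generator=generator,
    )["sample"][0]
    image.save("fantasy_landscape.png")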