Mirror of https://github.com/huggingface/diffusers.git
Support K-LMS in img2img (#270)
* Support K-LMS in img2img
* Apply review suggestions
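For context, a minimal usage sketch of what this commit enables: running the img2img pipeline with the K-LMS scheduler instead of PNDM/DDIM. This is an illustration only; the model id, image path, and the module the pipeline is imported from are placeholders, not taken from this diff.

    import PIL.Image
    from diffusers import LMSDiscreteScheduler

    # the img2img pipeline class modified in this commit; the import path depends on where
    # it lives in this revision (hypothetical module name below)
    from image_to_image import StableDiffusionImg2ImgPipeline

    # K-LMS scheduler with the beta settings commonly used for Stable Diffusion
    scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")

    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",  # placeholder model id
        scheduler=scheduler,
    ).to("cuda")

    init_image = PIL.Image.open("sketch.png").convert("RGB").resize((512, 512))  # placeholder input
    result = pipe(prompt="a fantasy landscape", init_image=init_image, strength=0.75)
    result["sample"][0].save("out.png")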
@@ -5,7 +5,14 @@ import numpy as np
 import torch

 import PIL
-from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, PNDMScheduler, UNet2DConditionModel
+from diffusers import (
+    AutoencoderKL,
+    DDIMScheduler,
+    DiffusionPipeline,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+    UNet2DConditionModel,
+)
 from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -87,12 +94,17 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         # get the original timestep using init_timestep
         init_timestep = int(num_inference_steps * strength) + offset
         init_timestep = min(init_timestep, num_inference_steps)
-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+        if isinstance(self.scheduler, LMSDiscreteScheduler):
+            timesteps = torch.tensor(
+                [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
+            )
+        else:
+            timesteps = self.scheduler.timesteps[-init_timestep]
+            timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)

         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
-        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(self.device)

         # get prompt text embeddings
         text_input = self.tokenizer(
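A worked example of the timestep bookkeeping above (numbers are illustrative, not from the diff): the K-LMS branch selects a position in the schedule, while the other schedulers use the raw timestep value at that position.

    num_inference_steps, strength, offset = 50, 0.75, 1
    init_timestep = int(num_inference_steps * strength) + offset   # 38
    init_timestep = min(init_timestep, num_inference_steps)        # 38

    # K-LMS: sigmas and timesteps are addressed by position in the schedule,
    # so the noise level applied to the init image sits at index 50 - 38 = 12.
    lms_start_index = num_inference_steps - init_timestep           # 12

    # DDIM / PNDM: the diffusion timestep *value* at that same position is used instead,
    # i.e. scheduler.timesteps[-init_timestep].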
@@ -133,8 +145,15 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         latents = init_latents
         t_start = max(num_inference_steps - init_timestep + offset, 0)
         for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
+            t_index = t_start + i
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                sigma = self.scheduler.sigmas[t_index]
+                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+                latent_model_input = latent_model_input.to(self.unet.dtype)
+                t = t.to(self.unet.dtype)

             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
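The scaling comment above can be made concrete. The sketch below is not part of the diff; it only illustrates, under the variance-exploding parameterization K-LMS uses (noisy latents = clean latents + sigma * noise), why dividing by (sigma**2 + 1) ** 0.5 brings the model input back to roughly the unit variance the UNet was trained on.

    import torch

    sigma = torch.tensor(7.0)                        # an illustrative sigma from the K-LMS schedule
    x0 = torch.randn(4, 4, 64, 64)                   # clean latents, ~unit variance
    x_sigma = x0 + sigma * torch.randn_like(x0)      # variance-exploding noisy latents: std ~ sqrt(sigma**2 + 1)
    model_input = x_sigma / ((sigma**2 + 1) ** 0.5)  # rescaled to ~unit variance before the UNet sees it
    print(x_sigma.std(), model_input.std())          # roughly 7.07 and 1.0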
@@ -145,11 +164,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs)["prev_sample"]
+            else:
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]

         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents)
+        image = self.vae.decode(latents.to(self.vae.dtype))

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
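The branch above exists because, as the diff implies, LMSDiscreteScheduler.step in this revision looks up its sigma by position in the schedule, whereas DDIM and PNDM expect the diffusion timestep value itself. A minimal sketch of the two call shapes, with variable names as in the loop above:

    if isinstance(scheduler, LMSDiscreteScheduler):
        # K-LMS: pass the loop index into scheduler.sigmas / scheduler.timesteps
        latents = scheduler.step(noise_pred, t_index, latents)["prev_sample"]
    else:
        # DDIM / PNDM: pass the raw timestep value (e.g. 981, 961, ...)
        latents = scheduler.step(noise_pred, t, latents)["prev_sample"]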
@@ -138,6 +138,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
             if isinstance(self.scheduler, LMSDiscreteScheduler):
                 sigma = self.scheduler.sigmas[i]
+                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
                 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

             # predict the noise residual
@@ -124,8 +124,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         return {"prev_sample": prev_sample}

     def add_noise(self, original_samples, noise, timesteps):
-        sigmas = self.match_shape(self.sigmas, noise)
-        noisy_samples = original_samples + noise * sigmas[timesteps]
+        sigmas = self.match_shape(self.sigmas[timesteps], noise)
+        noisy_samples = original_samples + noise * sigmas

         return noisy_samples

     def __len__(self):