1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Custom sampler support for Stable Cascade Decoder (#9132)

Custom sampler support Stable Cascade Decoder
This commit is contained in:
Disty0
2024-08-20 22:56:53 +03:00
committed by GitHub
parent 214990e5f2
commit 21682bab7e
2 changed files with 37 additions and 5 deletions

View File

@@ -281,6 +281,16 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
def num_timesteps(self):
return self._num_timesteps
def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
s = torch.tensor([0.008])
clamp_range = [0, 1]
min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
var = alphas_cumprod[t]
var = var.clamp(*clamp_range)
s, min_var = s.to(var.device), min_var.to(var.device)
ratio = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
return ratio
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -434,10 +444,30 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
)
if isinstance(self.scheduler, DDPMWuerstchenScheduler):
timesteps = timesteps[:-1]
else:
if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample:
self.scheduler.config.clip_sample = False # disample sample clipping
logger.warning(" set `clip_sample` to be False")
# 6. Run denoising loop
self._num_timesteps = len(timesteps[:-1])
for i, t in enumerate(self.progress_bar(timesteps[:-1])):
timestep_ratio = t.expand(latents.size(0)).to(dtype)
if hasattr(self.scheduler, "betas"):
alphas = 1.0 - self.scheduler.betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
else:
alphas_cumprod = []
self._num_timesteps = len(timesteps)
for i, t in enumerate(self.progress_bar(timesteps)):
if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
if len(alphas_cumprod) > 0:
timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod)
timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device)
else:
timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype)
else:
timestep_ratio = t.expand(latents.size(0)).to(dtype)
# 7. Denoise latents
predicted_latents = self.decoder(
@@ -454,6 +484,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale)
# 9. Renoise latents to next timestep
if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
timestep_ratio = t
latents = self.scheduler.step(
model_output=predicted_latents,
timestep=timestep_ratio,

View File

@@ -353,7 +353,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
return self._num_timesteps
def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
s = torch.tensor([0.003])
s = torch.tensor([0.008])
clamp_range = [0, 1]
min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
var = alphas_cumprod[t]
@@ -557,7 +557,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
if isinstance(self.scheduler, DDPMWuerstchenScheduler):
timesteps = timesteps[:-1]
else:
if self.scheduler.config.clip_sample:
if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample:
self.scheduler.config.clip_sample = False # disample sample clipping
logger.warning(" set `clip_sample` to be False")
# 6. Run denoising loop