Mirror of https://github.com/huggingface/diffusers.git
Custom sampler support for Stable Cascade Decoder (#9132)
Adds custom sampler support to StableCascadeDecoderPipeline: a new get_timestep_ratio_conditioning helper maps a standard scheduler's alphas_cumprod back to the continuous timestep ratio the decoder expects, and the denoising loop now accepts schedulers other than DDPMWuerstchenScheduler. StableCascadePriorPipeline receives matching fixes: its cosine-schedule offset moves from 0.003 to 0.008 and its clip_sample check gains a hasattr guard.
@@ -281,6 +281,16 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
     def num_timesteps(self):
         return self._num_timesteps
 
+    def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
+        s = torch.tensor([0.008])
+        clamp_range = [0, 1]
+        min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
+        var = alphas_cumprod[t]
+        var = var.clamp(*clamp_range)
+        s, min_var = s.to(var.device), min_var.to(var.device)
+        ratio = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
+        return ratio
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
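The new helper inverts the cosine noise schedule of Nichol & Dhariwal, alpha_bar(r) = cos(((r + s) / (1 + s)) * pi / 2) ** 2 normalized so that alpha_bar(0) = 1: given a scheduler's cumulative alpha product at step t, it recovers the continuous ratio r in [0, 1] that the Würstchen-style decoder was trained to be conditioned on. A minimal standalone sketch of the mapping, using a toy linear-beta schedule (not part of the diff):

    import torch

    def get_timestep_ratio_conditioning(t, alphas_cumprod, s=0.008):
        # Invert alpha_bar(r) = cos(((r + s) / (1 + s)) * pi / 2) ** 2 / min_var for r.
        s = torch.tensor([s])
        min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
        var = alphas_cumprod[t].clamp(0, 1)
        return (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s

    # Toy DDPM-style linear beta schedule with 1000 training steps.
    betas = torch.linspace(1e-4, 0.02, 1000)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    print(get_timestep_ratio_conditioning(0, alphas_cumprod))    # ~0: clean image
    print(get_timestep_ratio_conditioning(999, alphas_cumprod))  # ~1: pure noise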
@@ -434,10 +444,30 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
             batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
         )
 
+        if isinstance(self.scheduler, DDPMWuerstchenScheduler):
+            timesteps = timesteps[:-1]
+        else:
+            if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample:
+                self.scheduler.config.clip_sample = False  # disable sample clipping
+                logger.warning("set `clip_sample` to False")
+
         # 6. Run denoising loop
-        self._num_timesteps = len(timesteps[:-1])
-        for i, t in enumerate(self.progress_bar(timesteps[:-1])):
-            timestep_ratio = t.expand(latents.size(0)).to(dtype)
+        if hasattr(self.scheduler, "betas"):
+            alphas = 1.0 - self.scheduler.betas
+            alphas_cumprod = torch.cumprod(alphas, dim=0)
+        else:
+            alphas_cumprod = []
+
+        self._num_timesteps = len(timesteps)
+        for i, t in enumerate(self.progress_bar(timesteps)):
+            if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
+                if len(alphas_cumprod) > 0:
+                    timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod)
+                    timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device)
+                else:
+                    timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype)
+            else:
+                timestep_ratio = t.expand(latents.size(0)).to(dtype)
 
         # 7. Denoise latents
         predicted_latents = self.decoder(
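In the rewritten loop, any scheduler that exposes betas has its alphas_cumprod rebuilt on the fly so the helper above can translate discrete timesteps into conditioning ratios; schedulers without betas fall back to a plain t / t_max normalization. A quick sanity-check sketch that the manual cumulative product matches what DDPMScheduler already stores (an assumption worth verifying against your diffusers version):

    import torch
    from diffusers import DDPMScheduler

    sched = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")

    # Rebuild the cumulative alpha product exactly as the pipeline does.
    alphas_cumprod = torch.cumprod(1.0 - sched.betas, dim=0)

    # DDPMScheduler keeps the same quantity as an attribute; the two should agree.
    assert torch.allclose(alphas_cumprod, sched.alphas_cumprod)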
@@ -454,6 +484,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
                 predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale)
 
             # 9. Renoise latents to next timestep
+            if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
+                timestep_ratio = t
             latents = self.scheduler.step(
                 model_output=predicted_latents,
                 timestep=timestep_ratio,
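The reassignment timestep_ratio = t just before step() matters because the two scheduler families disagree about what timestep means: DDPMWuerstchenScheduler works on the continuous ratio that conditions the model, while standard schedulers such as DDPMScheduler expect the raw discrete index. A small sketch of the standard convention (shapes are illustrative):

    import torch
    from diffusers import DDPMScheduler

    sched = DDPMScheduler(num_train_timesteps=1000)
    sched.set_timesteps(20)

    sample = torch.randn(1, 4, 32, 32)     # illustrative latent shape
    noise_pred = torch.randn_like(sample)  # stand-in for a model output

    t = sched.timesteps[0]  # a discrete index such as tensor(950), not a ratio in [0, 1]
    sample = sched.step(model_output=noise_pred, timestep=t, sample=sample).prev_sample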
@@ -353,7 +353,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
         return self._num_timesteps
 
     def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
-        s = torch.tensor([0.003])
+        s = torch.tensor([0.008])
         clamp_range = [0, 1]
         min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
         var = alphas_cumprod[t]
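The prior pipeline previously used an offset of s = 0.003; the change brings it in line with the decoder's copy of the helper and with the standard cosine-schedule offset s = 0.008. A rough numeric sketch of how much the offset shifts the recovered ratio for one mid-schedule alpha_bar value:

    import torch

    def ratio(var, s):
        s = torch.tensor([s])
        min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
        return (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s

    var = torch.tensor(0.5)  # an example alpha_bar value mid-schedule
    print(ratio(var, 0.003).item(), ratio(var, 0.008).item())  # ~0.4985 vs ~0.4960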
@@ -557,7 +557,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
         if isinstance(self.scheduler, DDPMWuerstchenScheduler):
             timesteps = timesteps[:-1]
         else:
-            if self.scheduler.config.clip_sample:
+            if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample:
                 self.scheduler.config.clip_sample = False  # disable sample clipping
                 logger.warning("set `clip_sample` to False")
         # 6. Run denoising loop
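Taken together, both Stable Cascade stages can now run on a standard sampler end to end. A hedged usage sketch: the Hub checkpoint names, dtypes, and call arguments follow the public Stable Cascade release and are assumptions rather than part of this diff:

    import torch
    from diffusers import DDPMScheduler, StableCascadeDecoderPipeline, StableCascadePriorPipeline

    prior = StableCascadePriorPipeline.from_pretrained(
        "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
    ).to("cuda")
    decoder = StableCascadeDecoderPipeline.from_pretrained(
        "stabilityai/stable-cascade", torch_dtype=torch.float16
    ).to("cuda")

    # Swap both stages onto a standard DDPM sampler with a cosine schedule.
    prior.scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")
    decoder.scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")

    prompt = "an astronaut riding a green horse"
    prior_output = prior(prompt=prompt, num_inference_steps=20, guidance_scale=4.0)
    image = decoder(
        image_embeddings=prior_output.image_embeddings.to(torch.float16),
        prompt=prompt,
        num_inference_steps=10,
        guidance_scale=0.0,
    ).images[0]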