diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
index af80cf805a..111ccc40c5 100644
--- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
+++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
@@ -281,6 +281,16 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
     def num_timesteps(self):
         return self._num_timesteps
 
+    def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
+        s = torch.tensor([0.008])
+        clamp_range = [0, 1]
+        min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
+        var = alphas_cumprod[t]
+        var = var.clamp(*clamp_range)
+        s, min_var = s.to(var.device), min_var.to(var.device)
+        ratio = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
+        return ratio
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -434,10 +444,30 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
             batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
         )
 
+        if isinstance(self.scheduler, DDPMWuerstchenScheduler):
+            timesteps = timesteps[:-1]
+        else:
+            if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample:
+                self.scheduler.config.clip_sample = False # disample sample clipping
+                logger.warning(" set `clip_sample` to be False")
+
         # 6. Run denoising loop
-        self._num_timesteps = len(timesteps[:-1])
-        for i, t in enumerate(self.progress_bar(timesteps[:-1])):
-            timestep_ratio = t.expand(latents.size(0)).to(dtype)
+        if hasattr(self.scheduler, "betas"):
+            alphas = 1.0 - self.scheduler.betas
+            alphas_cumprod = torch.cumprod(alphas, dim=0)
+        else:
+            alphas_cumprod = []
+
+        self._num_timesteps = len(timesteps)
+        for i, t in enumerate(self.progress_bar(timesteps)):
+            if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
+                if len(alphas_cumprod) > 0:
+                    timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod)
+                    timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device)
+                else:
+                    timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype)
+            else:
+                timestep_ratio = t.expand(latents.size(0)).to(dtype)
 
             # 7. Denoise latents
             predicted_latents = self.decoder(
@@ -454,6 +484,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
                 predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale)
 
             # 9. Renoise latents to next timestep
+            if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
+                timestep_ratio = t
             latents = self.scheduler.step(
                 model_output=predicted_latents,
                 timestep=timestep_ratio,
diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
index dc6c81e1a8..058dbf6b07 100644
--- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
+++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
@@ -353,7 +353,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
         return self._num_timesteps
 
     def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
-        s = torch.tensor([0.003])
+        s = torch.tensor([0.008])
         clamp_range = [0, 1]
         min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
         var = alphas_cumprod[t]
@@ -557,7 +557,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
         if isinstance(self.scheduler, DDPMWuerstchenScheduler):
             timesteps = timesteps[:-1]
         else:
-            if self.scheduler.config.clip_sample:
+            if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample:
                 self.scheduler.config.clip_sample = False # disample sample clipping
                 logger.warning(" set `clip_sample` to be False")
         # 6. Run denoising loop