From 46fac824be8837abbc6348247fddf3ca53251047 Mon Sep 17 00:00:00 2001 From: Chi Date: Fri, 8 Mar 2024 16:31:59 +0530 Subject: [PATCH] Solve missing clip_sample implementation in FlaxDDIMScheduler. (#7017) * I added a new doc string to the class. This makes it easier for other developers to understand what it is doing and where it is used. * Update src/diffusers/models/unet_2d_blocks.py This change was suggested by a maintainer. Co-authored-by: Sayak Paul * Update src/diffusers/models/unet_2d_blocks.py Add suggested text Co-authored-by: Sayak Paul * Update unet_2d_blocks.py I changed the Parameter to Args text. * Update unet_2d_blocks.py set proper indentation in this file. * Update unet_2d_blocks.py a little bit of change in the act_fun argument line. * I ran the black command to reformat the code style * Update unet_2d_blocks.py add a doc-string similar to the one in the original diffusion repository. * Fix the bug mentioned in issue #6901 * Update src/diffusers/schedulers/scheduling_ddim_flax.py Co-authored-by: Pedro Cuenca * Fix linter * Restore empty line --------- Co-authored-by: Sayak Paul Co-authored-by: Pedro Cuenca --- src/diffusers/schedulers/scheduling_ddim_flax.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index 5f89119e26..0db2d61bf6 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -85,7 +85,9 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): trained_betas (`jnp.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. clip_sample (`bool`, default `True`): - option to clip predicted sample between -1 and 1 for numerical stability. + option to clip predicted sample for numerical stability. The clip range is determined by `clip_sample_range`. 
+ clip_sample_range (`float`, default `1.0`): + the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. set_alpha_to_one (`bool`, default `True`): each diffusion step uses the value of alphas product at that step and at the previous one. For the final step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, @@ -117,6 +119,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[jnp.ndarray] = None, + clip_sample: bool = True, + clip_sample_range: float = 1.0, set_alpha_to_one: bool = True, steps_offset: int = 0, prediction_type: str = "epsilon", @@ -267,6 +271,12 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): " `v_prediction`" ) + # 4. Clip or threshold "predicted x_0" + if self.config.clip_sample: + pred_original_sample = pred_original_sample.clip( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + # 4. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) variance = self._get_variance(state, timestep, prev_timestep)