mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Solve missing clip_sample implementation in FlaxDDIMScheduler. (#7017)
* I added a new doc string to the class. This is more flexible to understanding other developers what are doing and where it's using. * Update src/diffusers/models/unet_2d_blocks.py This changes suggest by maintener. Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * Update src/diffusers/models/unet_2d_blocks.py Add suggested text Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * Update unet_2d_blocks.py I changed the Parameter to Args text. * Update unet_2d_blocks.py proper indentation set in this file. * Update unet_2d_blocks.py a little bit of change in the act_fun argument line. * I run the black command to reformat style in the code * Update unet_2d_blocks.py similar doc-string add to have in the original diffusion repository. * Fix bug for mention in this issue section #6901 * Update src/diffusers/schedulers/scheduling_ddim_flax.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Fix linter * Restore empty line --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
@@ -85,7 +85,9 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
trained_betas (`jnp.ndarray`, optional):
|
||||
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||
clip_sample (`bool`, default `True`):
|
||||
option to clip predicted sample between -1 and 1 for numerical stability.
|
||||
option to clip predicted sample between for numerical stability. The clip range is determined by `clip_sample_range`.
|
||||
clip_sample_range (`float`, default `1.0`):
|
||||
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
||||
set_alpha_to_one (`bool`, default `True`):
|
||||
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
|
||||
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
||||
@@ -117,6 +119,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[jnp.ndarray] = None,
|
||||
clip_sample: bool = True,
|
||||
clip_sample_range: float = 1.0,
|
||||
set_alpha_to_one: bool = True,
|
||||
steps_offset: int = 0,
|
||||
prediction_type: str = "epsilon",
|
||||
@@ -267,6 +271,12 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
" `v_prediction`"
|
||||
)
|
||||
|
||||
# 4. Clip or threshold "predicted x_0"
|
||||
if self.config.clip_sample:
|
||||
pred_original_sample = pred_original_sample.clip(
|
||||
-self.config.clip_sample_range, self.config.clip_sample_range
|
||||
)
|
||||
|
||||
# 4. compute variance: "sigma_t(η)" -> see formula (16)
|
||||
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
||||
variance = self._get_variance(state, timestep, prev_timestep)
|
||||
|
||||
Reference in New Issue
Block a user