
Add stochastic sampling to FlowMatchEulerDiscreteScheduler (#11369)

* Add stochastic sampling to FlowMatchEulerDiscreteScheduler

This PR adds stochastic sampling to FlowMatchEulerDiscreteScheduler, based on LTX-Video's `ltx_video/schedulers/rf.py` at commit b1aeddd7cc (a sketch of the resulting update rule follows the metadata below).

* Apply style fixes

* Use config value directly

* Apply style fixes

* Swap order

* Update src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Author: apolinário
Date: 2025-04-22 05:18:30 +02:00
Committer: GitHub
Parent: f59df3bb8b
Commit: 6ab62c7431
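
For context, here is a minimal standalone sketch of the update rule the diff below implements. It is an illustration only, not part of the scheduler's API; the function name and signature are made up for this sketch.

import torch

def stochastic_flow_step(sample, model_output, sigma, sigma_next):
    # Under flow matching, a noisy sample is x_t = (1 - sigma) * x0 + sigma * noise,
    # and the model predicts the velocity v = noise - x0, so the denoised
    # estimate is recovered as x0 = x_t - sigma * v.
    x0 = sample - sigma * model_output
    # Instead of the deterministic Euler update x_t + (sigma_next - sigma) * v,
    # re-noise the estimate down to the next (lower) noise level with fresh noise.
    noise = torch.randn_like(sample)
    return (1.0 - sigma_next) * x0 + sigma_next * noise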

@@ -80,6 +80,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             Whether to use beta sigmas for step sizes in the noise schedule during sampling.
         time_shift_type (`str`, defaults to "exponential"):
             The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
+        stochastic_sampling (`bool`, defaults to False):
+            Whether to use stochastic sampling.
     """
 
     _compatibles = []
@@ -101,6 +103,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         use_exponential_sigmas: Optional[bool] = False,
         use_beta_sigmas: Optional[bool] = False,
         time_shift_type: str = "exponential",
+        stochastic_sampling: bool = False,
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
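
The new flag is plain scheduler config, so it can be set at construction time or layered onto an existing config. A usage sketch (assuming `pipe` is a pipeline that already uses this scheduler; the variable name is illustrative):

from diffusers import FlowMatchEulerDiscreteScheduler

# Construct the scheduler with stochastic sampling enabled ...
scheduler = FlowMatchEulerDiscreteScheduler(stochastic_sampling=True)

# ... or rebuild an existing pipeline's scheduler, keeping the rest of its config:
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
    pipe.scheduler.config, stochastic_sampling=True
)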
@@ -437,13 +440,25 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             lower_mask = sigmas < per_token_sigmas[None] - 1e-6
             lower_sigmas = lower_mask * sigmas
             lower_sigmas, _ = lower_sigmas.max(dim=0)
-            dt = (per_token_sigmas - lower_sigmas)[..., None]
+
+            current_sigma = per_token_sigmas[..., None]
+            next_sigma = lower_sigmas[..., None]
+            dt = current_sigma - next_sigma
         else:
-            sigma = self.sigmas[self.step_index]
-            sigma_next = self.sigmas[self.step_index + 1]
+            sigma_idx = self.step_index
+            sigma = self.sigmas[sigma_idx]
+            sigma_next = self.sigmas[sigma_idx + 1]
+
+            current_sigma = sigma
+            next_sigma = sigma_next
             dt = sigma_next - sigma
 
-        prev_sample = sample + dt * model_output
+        if self.config.stochastic_sampling:
+            x0 = sample - current_sigma * model_output
+            noise = torch.randn_like(sample)
+            prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise
+        else:
+            prev_sample = sample + dt * model_output
 
         # upon completion increase step index by one
         self._step_index += 1
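
Why this is the natural stochastic counterpart of the Euler step: for an exact noisy sample x_t = (1 - sigma) * x0 + sigma * eps, the deterministic update x_t + (sigma_next - sigma) * (eps - x0) simplifies to (1 - sigma_next) * x0 + sigma_next * eps, i.e. renoising with the same noise; the stochastic branch simply swaps in fresh noise. A small numerical check (illustrative, not from the PR):

import torch

torch.manual_seed(0)
x0, eps = torch.randn(4), torch.randn(4)
sigma, sigma_next = 0.8, 0.6

x_t = (1 - sigma) * x0 + sigma * eps                 # noisy sample
v = eps - x0                                         # flow-matching velocity

euler = x_t + (sigma_next - sigma) * v               # deterministic Euler step
renoised = (1 - sigma_next) * x0 + sigma_next * eps  # renoise with the same eps

assert torch.allclose(euler, renoised)               # identical up to float error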