mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add stochastic sampling to FlowMatchEulerDiscreteScheduler (#11369)
* Add stochastic sampling to FlowMatchEulerDiscreteScheduler
This PR adds stochastic sampling to FlowMatchEulerDiscreteScheduler based on b1aeddd7cc ltx_video/schedulers/rf.py
* Apply style fixes
* Use config value directly
* Apply style fixes
* Swap order
* Update src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>
* Update src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
@@ -80,6 +80,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
Whether to use beta sigmas for step sizes in the noise schedule during sampling.
|
||||
time_shift_type (`str`, defaults to "exponential"):
|
||||
The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
|
||||
stochastic_sampling (`bool`, defaults to False):
|
||||
Whether to use stochastic sampling.
|
||||
"""
|
||||
|
||||
_compatibles = []
|
||||
@@ -101,6 +103,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
use_exponential_sigmas: Optional[bool] = False,
|
||||
use_beta_sigmas: Optional[bool] = False,
|
||||
time_shift_type: str = "exponential",
|
||||
stochastic_sampling: bool = False,
|
||||
):
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
@@ -437,13 +440,25 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
lower_mask = sigmas < per_token_sigmas[None] - 1e-6
|
||||
lower_sigmas = lower_mask * sigmas
|
||||
lower_sigmas, _ = lower_sigmas.max(dim=0)
|
||||
dt = (per_token_sigmas - lower_sigmas)[..., None]
|
||||
|
||||
current_sigma = per_token_sigmas[..., None]
|
||||
next_sigma = lower_sigmas[..., None]
|
||||
dt = current_sigma - next_sigma
|
||||
else:
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sigma_next = self.sigmas[self.step_index + 1]
|
||||
sigma_idx = self.step_index
|
||||
sigma = self.sigmas[sigma_idx]
|
||||
sigma_next = self.sigmas[sigma_idx + 1]
|
||||
|
||||
current_sigma = sigma
|
||||
next_sigma = sigma_next
|
||||
dt = sigma_next - sigma
|
||||
|
||||
prev_sample = sample + dt * model_output
|
||||
if self.config.stochastic_sampling:
|
||||
x0 = sample - current_sigma * model_output
|
||||
noise = torch.randn_like(sample)
|
||||
prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise
|
||||
else:
|
||||
prev_sample = sample + dt * model_output
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
Reference in New Issue
Block a user