mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix wrong param types, docs, and handles noise=None in scale_noise of FlowMatching schedulers (#11669)
* Bug: Fix wrong params, docs, and handles noise=None * make noise a required arg --------- Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
@@ -171,8 +171,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
def scale_noise(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[float, torch.FloatTensor],
|
||||
noise: Optional[torch.FloatTensor] = None,
|
||||
timestep: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Forward process in flow-matching
|
||||
@@ -180,8 +180,10 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The input sample.
|
||||
timestep (`int`, *optional*):
|
||||
timestep (`torch.FloatTensor`):
|
||||
The current timestep in the diffusion chain.
|
||||
noise (`torch.FloatTensor`):
|
||||
The noise tensor.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`:
|
||||
|
||||
@@ -110,8 +110,8 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
def scale_noise(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[float, torch.FloatTensor],
|
||||
noise: Optional[torch.FloatTensor] = None,
|
||||
timestep: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Forward process in flow-matching
|
||||
@@ -119,8 +119,10 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The input sample.
|
||||
timestep (`int`, *optional*):
|
||||
timestep (`torch.FloatTensor`):
|
||||
The current timestep in the diffusion chain.
|
||||
noise (`torch.FloatTensor`):
|
||||
The noise tensor.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`:
|
||||
@@ -130,6 +132,7 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._init_step_index(timestep)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
|
||||
sample = sigma * noise + (1.0 - sigma) * sample
|
||||
|
||||
return sample
|
||||
|
||||
@@ -192,8 +192,8 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
def scale_noise(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[float, torch.FloatTensor],
|
||||
noise: Optional[torch.FloatTensor] = None,
|
||||
timestep: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Forward process in flow-matching
|
||||
@@ -201,8 +201,10 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The input sample.
|
||||
timestep (`int`, *optional*):
|
||||
timestep (`torch.FloatTensor`):
|
||||
The current timestep in the diffusion chain.
|
||||
noise (`torch.FloatTensor`):
|
||||
The noise tensor.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`:
|
||||
|
||||
Reference in New Issue
Block a user