diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 9fd61d9e18..378a62ca8a 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -171,8 +171,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): def scale_noise( self, sample: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], - noise: Optional[torch.FloatTensor] = None, + timestep: torch.FloatTensor, + noise: torch.FloatTensor, ) -> torch.FloatTensor: """ Forward process in flow-matching @@ -180,8 +180,10 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): Args: sample (`torch.FloatTensor`): The input sample. - timestep (`int`, *optional*): + timestep (`torch.FloatTensor`): The current timestep in the diffusion chain. + noise (`torch.FloatTensor`): + The noise tensor. Returns: `torch.FloatTensor`: diff --git a/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py index 6febee444c..6b85194f8b 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py @@ -110,8 +110,8 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin): def scale_noise( self, sample: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], - noise: Optional[torch.FloatTensor] = None, + timestep: torch.FloatTensor, + noise: torch.FloatTensor, ) -> torch.FloatTensor: """ Forward process in flow-matching @@ -119,8 +119,10 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin): Args: sample (`torch.FloatTensor`): The input sample. - timestep (`int`, *optional*): + timestep (`torch.FloatTensor`): The current timestep in the diffusion chain. + noise (`torch.FloatTensor`): + The noise tensor. Returns: `torch.FloatTensor`: @@ -130,6 +132,7 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin): self._init_step_index(timestep) sigma = self.sigmas[self.step_index] + sample = sigma * noise + (1.0 - sigma) * sample return sample diff --git a/src/diffusers/schedulers/scheduling_flow_match_lcm.py b/src/diffusers/schedulers/scheduling_flow_match_lcm.py index 25186d1fe9..8ef0e2ec81 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_lcm.py +++ b/src/diffusers/schedulers/scheduling_flow_match_lcm.py @@ -192,8 +192,8 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin): def scale_noise( self, sample: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], - noise: Optional[torch.FloatTensor] = None, + timestep: torch.FloatTensor, + noise: torch.FloatTensor, ) -> torch.FloatTensor: """ Forward process in flow-matching @@ -201,8 +201,10 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin): Args: sample (`torch.FloatTensor`): The input sample. - timestep (`int`, *optional*): + timestep (`torch.FloatTensor`): The current timestep in the diffusion chain. + noise (`torch.FloatTensor`): + The noise tensor. Returns: `torch.FloatTensor`: