diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 685765779e..11073ce491 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -674,7 +674,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: int, + timestep: Union[int, torch.Tensor], sample: torch.Tensor, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: @@ -685,7 +685,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. - timestep (`float`): + timestep (`int`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 0f0e529605..4472a06c34 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -920,7 +920,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: int, + timestep: Union[int, torch.Tensor], sample: torch.Tensor, generator=None, variance_noise: Optional[torch.Tensor] = None, diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 4695941e72..6628a92ba0 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -787,7 +787,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: int, + timestep: Union[int, torch.Tensor], sample: torch.Tensor, generator=None, variance_noise: Optional[torch.Tensor] = None, diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index a5eb33d38a..1a10fff043 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -927,7 +927,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: int, + timestep: Union[int, torch.Tensor], sample: torch.Tensor, generator=None, return_dict: bool = True, diff --git a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py index 6eef247bfd..21ae1df00a 100644 --- a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py @@ -594,7 +594,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: int, + timestep: Union[int, torch.Tensor], sample: torch.Tensor, generator=None, return_dict: bool = True, diff --git a/src/diffusers/schedulers/scheduling_ipndm.py b/src/diffusers/schedulers/scheduling_ipndm.py index 9f15f1fe0a..28f349ae21 100644 --- a/src/diffusers/schedulers/scheduling_ipndm.py +++ b/src/diffusers/schedulers/scheduling_ipndm.py @@ -138,7 +138,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: int, + timestep: Union[int, torch.Tensor], sample: torch.Tensor, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 367761ce3f..995f85c020 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -822,7 +822,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: int, + timestep: Union[int, torch.Tensor], sample: torch.Tensor, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: