mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
allow tensors in several schedulers step() call (#8905)
This commit is contained in:
@@ -674,7 +674,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
timestep: int,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
sample: torch.Tensor,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
@@ -685,7 +685,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from learned diffusion model.
|
||||
timestep (`float`):
|
||||
timestep (`int`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -920,7 +920,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
timestep: int,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
sample: torch.Tensor,
|
||||
generator=None,
|
||||
variance_noise: Optional[torch.Tensor] = None,
|
||||
|
||||
@@ -787,7 +787,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
timestep: int,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
sample: torch.Tensor,
|
||||
generator=None,
|
||||
variance_noise: Optional[torch.Tensor] = None,
|
||||
|
||||
@@ -927,7 +927,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
timestep: int,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
sample: torch.Tensor,
|
||||
generator=None,
|
||||
return_dict: bool = True,
|
||||
|
||||
@@ -594,7 +594,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
timestep: int,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
sample: torch.Tensor,
|
||||
generator=None,
|
||||
return_dict: bool = True,
|
||||
|
||||
@@ -138,7 +138,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
timestep: int,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
sample: torch.Tensor,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
|
||||
@@ -822,7 +822,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
timestep: int,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
sample: torch.Tensor,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
|
||||
Reference in New Issue
Block a user