1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

allow tensors in several schedulers step() call (#8905)

This commit is contained in:
Pierre Chapuis
2024-07-20 06:58:06 +02:00
committed by GitHub
parent 461efc57c5
commit fe7948941d
7 changed files with 8 additions and 8 deletions

View File

@@ -674,7 +674,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
timestep: int,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
@@ -685,7 +685,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.

View File

@@ -920,7 +920,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
timestep: int,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
generator=None,
variance_noise: Optional[torch.Tensor] = None,

View File

@@ -787,7 +787,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
timestep: int,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
generator=None,
variance_noise: Optional[torch.Tensor] = None,

View File

@@ -927,7 +927,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
timestep: int,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
generator=None,
return_dict: bool = True,

View File

@@ -594,7 +594,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
timestep: int,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
generator=None,
return_dict: bool = True,

View File

@@ -138,7 +138,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
timestep: int,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:

View File

@@ -822,7 +822,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
timestep: int,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]: