diff --git a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
index 2162994dab..01f3d464aa 100644
--- a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
@@ -15,14 +15,35 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
 
 import math
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
 from ..utils.torch_utils import randn_tensor
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils import SchedulerMixin
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EDMDPMSolverMultistep
+class EDMDPMSolverMultistepSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+
+    Args:
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
 
 
 class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
@@ -593,7 +614,8 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         sample: torch.Tensor,
         generator=None,
         return_dict: bool = True,
-    ) -> Union[SchedulerOutput, Tuple]:
+        pred_original_sample: Optional[torch.Tensor] = None,
+    ) -> Union[EDMDPMSolverMultistepSchedulerOutput, Tuple]:
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
         the multistep DPMSolver.
@@ -608,12 +630,14 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             generator (`torch.Generator`, *optional*):
                 A random number generator.
             return_dict (`bool`):
-                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
+                Whether or not to return a
+                [`~schedulers.scheduling_edm_dpmsolver_multistep.EDMDPMSolverMultistepSchedulerOutput`] or a `tuple`.
 
         Returns:
-            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
-                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
-                tuple is returned where the first element is the sample tensor.
+            [`~schedulers.scheduling_edm_dpmsolver_multistep.EDMDPMSolverMultistepSchedulerOutput`] or `tuple`:
+                If return_dict is `True`,
+                [`~schedulers.scheduling_edm_dpmsolver_multistep.EDMDPMSolverMultistepSchedulerOutput`] is returned,
+                otherwise a tuple is returned where the first element is the sample tensor.
 
         """
         if self.num_inference_steps is None:
@@ -634,7 +658,12 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
         )
 
-        model_output = self.convert_model_output(model_output, sample=sample)
+        if pred_original_sample is None:
+            model_output = self.convert_model_output(model_output, sample=sample)
+        else:
+            model_output = pred_original_sample
+            # TODO: thresholding is not handled in this case, but probably not needed either for Cosmos
+
         for i in range(self.config.solver_order - 1):
             self.model_outputs[i] = self.model_outputs[i + 1]
         self.model_outputs[-1] = model_output
@@ -662,7 +691,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         if not return_dict:
             return (prev_sample,)
 
-        return SchedulerOutput(prev_sample=prev_sample)
+        return EDMDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, pred_original_sample=model_output)
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
diff --git a/src/diffusers/schedulers/scheduling_edm_euler.py b/src/diffusers/schedulers/scheduling_edm_euler.py
index 24a8044f11..2e3bae6cbc 100644
--- a/src/diffusers/schedulers/scheduling_edm_euler.py
+++ b/src/diffusers/schedulers/scheduling_edm_euler.py
@@ -28,7 +28,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 @dataclass
-# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EDMEuler
class EDMEulerSchedulerOutput(BaseOutput):
     """
     Output class for the scheduler's `step` function output.