1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

support edm dpmsolver multistep

This commit is contained in:
Aryan
2025-06-11 03:44:16 +02:00
parent 06e852de71
commit 2d0174063c
2 changed files with 38 additions and 9 deletions

View File

@@ -15,14 +15,35 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from .scheduling_utils import SchedulerMixin
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EDMDPMSolverMultistep
class EDMDPMSolverMultistepSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.Tensor
pred_original_sample: Optional[torch.Tensor] = None
class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
@@ -593,7 +614,8 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
sample: torch.Tensor,
generator=None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
pred_original_sample: Optional[torch.Tensor] = None,
) -> Union[EDMDPMSolverMultistepSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
the multistep DPMSolver.
@@ -608,12 +630,14 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Whether or not to return a
[`~schedulers.scheduling_edm_dpmsolver_multistep.EDMDPMSolverMultistepSchedulerOutput`] or a `tuple`.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
[`~schedulers.scheduling_edm_dpmsolver_multistep.EDMDPMSolverMultistepSchedulerOutput`] or `tuple`:
If return_dict is `True`,
[`~schedulers.scheduling_edm_dpmsolver_multistep.EDMDPMSolverMultistepSchedulerOutput`] is returned,
otherwise a tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
@@ -634,7 +658,12 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
)
model_output = self.convert_model_output(model_output, sample=sample)
if pred_original_sample is None:
model_output = self.convert_model_output(model_output, sample=sample)
else:
model_output = pred_original_sample
# TODO: thresholding is not handled in this case, but probably not needed either for Cosmos
for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output
@@ -662,7 +691,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
return EDMDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, pred_original_sample=model_output)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(

View File

@@ -28,7 +28,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EDMEuler
class EDMEulerSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.