mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Merge branch 'main' into qwen-pipeline-mixin

sayakpaul committed 2025-11-24 13:46:22 +05:30
3 changed files with 72 additions and 29 deletions

View File

@@ -73,7 +73,7 @@ EXAMPLE_DOC_STRING = """
"""
class BriaFiboPipeline(DiffusionPipeline):
class BriaFiboPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
r"""
Args:
transformer (`BriaFiboTransformer2DModel`):
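
The change shown for the pipeline is the new base class: inheriting from `FluxLoraLoaderMixin` exposes the standard LoRA entry points (`load_lora_weights`, `unload_lora_weights`, and friends) on `BriaFiboPipeline`. A minimal usage sketch — the model and LoRA repository ids below are hypothetical placeholders:

```python
import torch
from diffusers import BriaFiboPipeline

# Hypothetical repository ids; substitute real checkpoints.
pipe = BriaFiboPipeline.from_pretrained("briaai/FIBO", torch_dtype=torch.bfloat16)

# These methods are inherited from FluxLoraLoaderMixin after this change.
pipe.load_lora_weights("some-user/some-fibo-lora", weight_name="lora.safetensors")
pipe.unload_lora_weights()
```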

View File

@@ -488,9 +488,20 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         t = t.reshape(sigma.shape)
         return t

-    # copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+    # Copied from diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler._convert_to_karras
     def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
-        """Constructs the noise schedule of Karras et al. (2022)."""
+        """
+        Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+        Models](https://huggingface.co/papers/2206.00364).
+
+        Args:
+            in_sigmas (`torch.Tensor`):
+                The input sigma values to be converted.
+
+        Returns:
+            `torch.Tensor`:
+                The converted sigma values following the Karras noise schedule.
+        """
         sigma_min: float = in_sigmas[-1].item()
         sigma_max: float = in_sigmas[0].item()
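
The lines that follow in the method body (elided here) turn `sigma_min` and `sigma_max` into the rho-spaced schedule from the paper. A standalone sketch of that formula, as a simplified free function rather than the scheduler method itself:

```python
import numpy as np

def karras_sigmas(sigma_min: float, sigma_max: float, num_steps: int, rho: float = 7.0) -> np.ndarray:
    # Equation (5) of Karras et al. (2022): interpolate between the extreme
    # sigmas in 1/rho space, then raise back to the rho power.
    ramp = np.linspace(0, 1, num_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
```

For example, `karras_sigmas(0.1, 10.0, 5)` yields a monotonically decreasing schedule from `10.0` down to `0.1`, with steps concentrated toward the low-sigma end.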

View File

@@ -99,15 +99,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
     methods the library implements for all schedulers such as loading and saving.

     Args:
-        num_train_timesteps (`int`, defaults to 1000):
+        num_train_timesteps (`int`, defaults to `1000`):
             The number of diffusion steps to train the model.
-        beta_start (`float`, defaults to 0.0001):
+        beta_start (`float`, defaults to `0.0001`):
             The starting `beta` value of inference.
-        beta_end (`float`, defaults to 0.02):
+        beta_end (`float`, defaults to `0.02`):
             The final `beta` value.
-        beta_schedule (`str`, defaults to `"linear"`):
-            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
-            `linear` or `scaled_linear`.
+        beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
+            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
         trained_betas (`np.ndarray`, *optional*):
             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
@@ -118,14 +117,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         use_beta_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
             Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
-        prediction_type (`str`, defaults to `epsilon`, *optional*):
+        prediction_type (`"epsilon"`, `"sample"`, or `"v_prediction"`, defaults to `"epsilon"`):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
             `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
             Video](https://imagen.research.google/video/paper.pdf) paper).
-        timestep_spacing (`str`, defaults to `"linspace"`):
+        timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
-        steps_offset (`int`, defaults to 0):
+        steps_offset (`int`, defaults to `0`):
             An offset added to the inference steps, as required by some model families.
     """
@@ -138,13 +137,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         num_train_timesteps: int = 1000,
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
-        beta_schedule: str = "linear",
+        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
         use_beta_sigmas: Optional[bool] = False,
-        prediction_type: str = "epsilon",
-        timestep_spacing: str = "linspace",
+        prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
+        timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
         steps_offset: int = 0,
     ):
         if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
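
The `Literal` annotations put the accepted strings in the signature itself, and the guard in the last context line above enforces that at most one of the sigma-schedule flags is set. A quick sketch of both behaviors, assuming the usual top-level import:

```python
from diffusers import LMSDiscreteScheduler

# The documented choices are now machine-checkable from the signature.
scheduler = LMSDiscreteScheduler(
    beta_schedule="scaled_linear",
    prediction_type="v_prediction",
    timestep_spacing="trailing",
)

# Enabling two sigma schedules at once should be rejected by the guard above.
try:
    LMSDiscreteScheduler(use_karras_sigmas=True, use_exponential_sigmas=True)
except ValueError as err:
    print(err)
```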
@@ -183,7 +182,15 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     @property
-    def init_noise_sigma(self):
+    def init_noise_sigma(self) -> Union[float, torch.Tensor]:
+        """
+        The standard deviation of the initial noise distribution.
+
+        Returns:
+            `float` or `torch.Tensor`:
+                The standard deviation of the initial noise distribution, computed based on the maximum sigma value and
+                the timestep spacing configuration.
+        """
         # standard deviation of the initial noise distribution
         if self.config.timestep_spacing in ["linspace", "trailing"]:
             return self.sigmas.max()
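
`init_noise_sigma` is the factor pipelines multiply freshly drawn latents by so that sampling starts at the schedule's initial noise level; for `"linspace"` and `"trailing"` spacing it is simply `sigmas.max()`. A sketch of the typical call site:

```python
import torch
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=30)

# Pipelines scale unit-variance noise up to the starting sigma.
latents = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
```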
@@ -191,21 +198,29 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         return (self.sigmas.max() ** 2 + 1) ** 0.5

     @property
-    def step_index(self):
+    def step_index(self) -> Optional[int]:
         """
-        The index counter for current timestep. It will increase 1 after each scheduler step.
+        The index counter for current timestep. It will increase by 1 after each scheduler step.
+
+        Returns:
+            `int` or `None`:
+                The current step index, or `None` if not initialized.
         """
         return self._step_index

     @property
-    def begin_index(self):
+    def begin_index(self) -> Optional[int]:
         """
         The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+
+        Returns:
+            `int` or `None`:
+                The begin index for the scheduler, or `None` if not set.
         """
         return self._begin_index

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
-    def set_begin_index(self, begin_index: int = 0):
+    def set_begin_index(self, begin_index: int = 0) -> None:
         """
         Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
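
`set_begin_index` lets a pipeline start partway into the schedule, which is how image-to-image style workflows skip the earliest, noisiest steps. A sketch of that pattern; the `strength` arithmetic mirrors what diffusers img2img pipelines commonly do:

```python
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=50)

# Denoise only the last 60% of the schedule, i.e. begin at step 20 of 50.
strength = 0.6
t_start = int(50 * (1 - strength))
scheduler.set_begin_index(t_start * scheduler.order)  # order is 1 for this scheduler
```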
@@ -239,14 +254,21 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.is_scale_input_called = True
         return sample

-    def get_lms_coefficient(self, order, t, current_order):
+    def get_lms_coefficient(self, order: int, t: int, current_order: int) -> float:
         """
         Compute the linear multistep coefficient.

         Args:
-            order ():
-            t ():
-            current_order ():
+            order (`int`):
+                The order of the linear multistep method.
+            t (`int`):
+                The current timestep index.
+            current_order (`int`):
+                The current order for which to compute the coefficient.
+
+        Returns:
+            `float`:
+                The computed linear multistep coefficient.
         """

         def lms_derivative(tau):
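
`lms_derivative`, whose body the hunk elides, is the Lagrange basis polynomial for node `current_order`; the coefficient is that polynomial integrated over the next sigma interval. A self-contained sketch of the same computation as a hypothetical free function over an explicit sigma array:

```python
import numpy as np
from scipy import integrate

def lms_coefficient(sigmas: np.ndarray, order: int, t: int, current_order: int) -> float:
    # Lagrange basis polynomial that is 1 at sigmas[t - current_order]
    # and 0 at the other `order` interpolation nodes.
    def lms_derivative(tau: float) -> float:
        prod = 1.0
        for k in range(order):
            if current_order == k:
                continue
            prod *= (tau - sigmas[t - k]) / (sigmas[t - current_order] - sigmas[t - k])
        return prod

    # Integrate the basis over the step from sigmas[t] to sigmas[t + 1].
    return integrate.quad(lms_derivative, sigmas[t], sigmas[t + 1], epsrel=1e-4)[0]
```

The scheduler's `step` method builds one such coefficient per retained derivative and combines them in the linear multistep update.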
@@ -261,7 +283,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         return integrated_coeff

-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None:
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
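
The tightened signature records that `device` may be omitted, in which case the computed sigmas and timesteps stay on CPU; passing a device places them there. For example:

```python
scheduler.set_timesteps(num_inference_steps=25)                 # tensors stay on CPU
scheduler.set_timesteps(num_inference_steps=25, device="cuda")  # assumes a CUDA device
```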
@@ -367,7 +389,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self._step_index = self._begin_index

     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
-    def _sigma_to_t(self, sigma, log_sigmas):
+    def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
         """
         Convert sigma values to corresponding timestep values through interpolation.
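
`_sigma_to_t` inverts the sigma schedule by piecewise-linear interpolation in log-sigma space. A simplified scalar sketch of the idea; the in-tree method is vectorized over arrays of sigmas:

```python
import numpy as np

def sigma_to_t(sigma: float, log_sigmas: np.ndarray) -> float:
    """Map one sigma to a fractional timestep; log_sigmas is assumed ascending."""
    log_sigma = np.log(max(sigma, 1e-10))
    # Index of the last schedule entry at or below log_sigma.
    dists = log_sigma - log_sigmas
    low_idx = min(int(np.cumsum(dists >= 0).argmax()), log_sigmas.shape[0] - 2)
    high_idx = low_idx + 1
    low, high = log_sigmas[low_idx], log_sigmas[high_idx]
    # Linear interpolation weight between the two bracketing entries.
    w = float(np.clip((low - log_sigma) / (low - high), 0, 1))
    return (1 - w) * low_idx + w * high_idx
```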
@@ -403,9 +425,19 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         t = t.reshape(sigma.shape)
         return t

     # copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
     def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
-        """Constructs the noise schedule of Karras et al. (2022)."""
+        """
+        Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+        Models](https://huggingface.co/papers/2206.00364).
+
+        Args:
+            in_sigmas (`torch.Tensor`):
+                The input sigma values to be converted.
+
+        Returns:
+            `torch.Tensor`:
+                The converted sigma values following the Karras noise schedule.
+        """
         sigma_min: float = in_sigmas[-1].item()
         sigma_max: float = in_sigmas[0].item()
@@ -629,5 +661,5 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         noisy_samples = original_samples + noise * sigma
         return noisy_samples

-    def __len__(self):
+    def __len__(self) -> int:
         return self.config.num_train_timesteps
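
Closing out the file: `add_noise` (the first two context lines of this hunk) implements the variance-exploding forward process `x_sigma = x_0 + sigma * noise`, and `__len__` reports the training horizon. A short sanity-check sketch:

```python
import torch
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=10)

x0 = torch.zeros(1, 4, 8, 8)
noise = torch.randn_like(x0)
noisy = scheduler.add_noise(x0, noise, scheduler.timesteps[:1])

# With x0 == 0 the result is exactly sigma_t * noise, so its scale should
# sit near the initial sigma of the schedule.
print(noisy.std().item(), float(scheduler.init_noise_sigma))
print(len(scheduler))  # num_train_timesteps, 1000 by default
```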