1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

remove unused __init__ arg for scm scheduler

This commit is contained in:
yiyixuxu
2025-03-20 02:55:00 +01:00
parent 4eef82b2c9
commit 398ca0c938
2 changed files with 28 additions and 66 deletions

View File

@@ -300,20 +300,11 @@ def main(args):
# SCM Scheduler for Sana Sprint
scheduler_config = {
"beta_end": 0.02,
"beta_schedule": "linear",
"beta_start": 0.0001,
"clip_sample": True,
"clip_sample_range": 1.0,
"dynamic_thresholding_ratio": 0.995,
"num_train_timesteps": 1000,
"prediction_type": "trigflow",
"rescale_betas_zero_snr": False,
"sample_max_value": 1.0,
"set_alpha_to_one": True,
"steps_offset": 0,
"thresholding": False,
"timestep_spacing": "leading",
"max_timesteps": 1.57080,
"intermediate_timesteps": 1.3,
"sigma_data": 0.5,
}
scheduler = SCMScheduler(**scheduler_config)
pipe = SanaSprintPipeline(

View File

@@ -16,7 +16,7 @@
# and https://github.com/hojonathanho/diffusion
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import numpy as np
import torch
@@ -57,45 +57,14 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
beta_start (`float`, defaults to 0.0001):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
clip_sample (`bool`, defaults to `True`):
Clip the predicted sample for numerical stability.
clip_sample_range (`float`, defaults to 1.0):
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
set_alpha_to_one (`bool`, defaults to `True`):
Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the alpha value at step 0.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
Diffusion.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
Video](https://imagen.research.google/video/paper.pdf) paper).
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
dynamic_thresholding_ratio (`float`, defaults to 0.995):
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
sample_max_value (`float`, defaults to 1.0):
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
timestep_spacing (`str`, defaults to `"leading"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
prediction_type (`str`, defaults to `trigflow`):
Prediction type of the scheduler function. Currently only supports "trigflow".
max_timesteps (`float`, defaults to 1.57080):
The maximum timestep value used in the diffusion process.
intermediate_timesteps (`float`, *optional*, defaults to 1.3):
The intermediate timestep value used when num_inference_steps=2.
sigma_data (`float`, defaults to 0.5):
The standard deviation of the noise added during multi-step inference.
"""
# _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -105,24 +74,26 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "trigflow",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
timestep_spacing: str = "leading",
rescale_betas_zero_snr: bool = False,
max_timesteps: float = 1.57080,
intermediate_timesteps: Optional[int] = 1.3,
intermediate_timesteps: Optional[float] = 1.3,
sigma_data: float = 0.5,
):
"""
Initialize the SCM scheduler.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
prediction_type (`str`, defaults to `trigflow`):
Prediction type of the scheduler function. Currently only supports "trigflow".
max_timesteps (`float`, defaults to 1.57080):
The maximum timestep value used in the diffusion process.
intermediate_timesteps (`float`, *optional*, defaults to 1.3):
The intermediate timestep value used when num_inference_steps=2.
sigma_data (`float`, defaults to 0.5):
The standard deviation of the noise added during multi-step inference.
"""
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0