mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
remove unused __init__ arg for scm scheduler
This commit is contained in:
@@ -300,20 +300,11 @@ def main(args):
|
||||
|
||||
# SCM Scheduler for Sana Sprint
|
||||
scheduler_config = {
|
||||
"beta_end": 0.02,
|
||||
"beta_schedule": "linear",
|
||||
"beta_start": 0.0001,
|
||||
"clip_sample": True,
|
||||
"clip_sample_range": 1.0,
|
||||
"dynamic_thresholding_ratio": 0.995,
|
||||
"num_train_timesteps": 1000,
|
||||
"prediction_type": "trigflow",
|
||||
"rescale_betas_zero_snr": False,
|
||||
"sample_max_value": 1.0,
|
||||
"set_alpha_to_one": True,
|
||||
"steps_offset": 0,
|
||||
"thresholding": False,
|
||||
"timestep_spacing": "leading",
|
||||
"max_timesteps": 1.57080,
|
||||
"intermediate_timesteps": 1.3,
|
||||
"sigma_data": 0.5,
|
||||
}
|
||||
scheduler = SCMScheduler(**scheduler_config)
|
||||
pipe = SanaSprintPipeline(
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
# and https://github.com/hojonathanho/diffusion
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -57,45 +57,14 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
The number of diffusion steps to train the model.
|
||||
beta_start (`float`, defaults to 0.0001):
|
||||
The starting `beta` value of inference.
|
||||
beta_end (`float`, defaults to 0.02):
|
||||
The final `beta` value.
|
||||
beta_schedule (`str`, defaults to `"linear"`):
|
||||
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
||||
trained_betas (`np.ndarray`, *optional*):
|
||||
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
||||
clip_sample (`bool`, defaults to `True`):
|
||||
Clip the predicted sample for numerical stability.
|
||||
clip_sample_range (`float`, defaults to 1.0):
|
||||
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
||||
set_alpha_to_one (`bool`, defaults to `True`):
|
||||
Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
|
||||
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
||||
otherwise it uses the alpha value at step 0.
|
||||
steps_offset (`int`, defaults to 0):
|
||||
An offset added to the inference steps. You can use a combination of `offset=1` and
|
||||
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
|
||||
Diffusion.
|
||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
||||
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
||||
Video](https://imagen.research.google/video/paper.pdf) paper).
|
||||
thresholding (`bool`, defaults to `False`):
|
||||
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
||||
as Stable Diffusion.
|
||||
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
||||
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
||||
sample_max_value (`float`, defaults to 1.0):
|
||||
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
||||
timestep_spacing (`str`, defaults to `"leading"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
||||
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
||||
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
||||
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
||||
prediction_type (`str`, defaults to `trigflow`):
|
||||
Prediction type of the scheduler function. Currently only supports "trigflow".
|
||||
max_timesteps (`float`, defaults to 1.57080):
|
||||
The maximum timestep value used in the diffusion process.
|
||||
intermediate_timesteps (`float`, *optional*, defaults to 1.3):
|
||||
The intermediate timestep value used when num_inference_steps=2.
|
||||
sigma_data (`float`, defaults to 0.5):
|
||||
The standard deviation of the noise added during multi-step inference.
|
||||
"""
|
||||
|
||||
# _compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
||||
@@ -105,24 +74,26 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
clip_sample: bool = True,
|
||||
set_alpha_to_one: bool = True,
|
||||
steps_offset: int = 0,
|
||||
prediction_type: str = "trigflow",
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
clip_sample_range: float = 1.0,
|
||||
sample_max_value: float = 1.0,
|
||||
timestep_spacing: str = "leading",
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
max_timesteps: float = 1.57080,
|
||||
intermediate_timesteps: Optional[int] = 1.3,
|
||||
intermediate_timesteps: Optional[float] = 1.3,
|
||||
sigma_data: float = 0.5,
|
||||
):
|
||||
"""
|
||||
Initialize the SCM scheduler.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
The number of diffusion steps to train the model.
|
||||
prediction_type (`str`, defaults to `trigflow`):
|
||||
Prediction type of the scheduler function. Currently only supports "trigflow".
|
||||
max_timesteps (`float`, defaults to 1.57080):
|
||||
The maximum timestep value used in the diffusion process.
|
||||
intermediate_timesteps (`float`, *optional*, defaults to 1.3):
|
||||
The intermediate timestep value used when num_inference_steps=2.
|
||||
sigma_data (`float`, defaults to 0.5):
|
||||
The standard deviation of the noise added during multi-step inference.
|
||||
"""
|
||||
# standard deviation of the initial noise distribution
|
||||
self.init_noise_sigma = 1.0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user