1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Flag Flax schedulers as deprecated (#13031)

flag flax schedulers as deprecated
This commit is contained in:
David El Malih
2026-01-26 18:41:48 +01:00
committed by GitHub
parent 2af7baa040
commit 956bdcc3ea
6 changed files with 141 additions and 24 deletions

View File

@@ -22,6 +22,7 @@ import jax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
@@ -31,6 +32,9 @@ from .scheduling_utils_flax import (
)
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class DPMSolverMultistepSchedulerState:
common: CommonSchedulerState
@@ -171,6 +175,10 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
timestep_spacing: str = "linspace",
dtype: jnp.dtype = jnp.float32,
):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.dtype = dtype
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState:
@@ -203,7 +211,10 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
)
def set_timesteps(
self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple
self,
state: DPMSolverMultistepSchedulerState,
num_inference_steps: int,
shape: Tuple,
) -> DPMSolverMultistepSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -301,10 +312,13 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
if self.config.thresholding:
# Dynamic thresholding in https://huggingface.co/papers/2205.11487
dynamic_max_val = jnp.percentile(
jnp.abs(x0_pred), self.config.dynamic_thresholding_ratio, axis=tuple(range(1, x0_pred.ndim))
jnp.abs(x0_pred),
self.config.dynamic_thresholding_ratio,
axis=tuple(range(1, x0_pred.ndim)),
)
dynamic_max_val = jnp.maximum(
dynamic_max_val, self.config.sample_max_value * jnp.ones_like(dynamic_max_val)
dynamic_max_val,
self.config.sample_max_value * jnp.ones_like(dynamic_max_val),
)
x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
return x0_pred
@@ -385,7 +399,11 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
m0, m1 = model_output_list[-1], model_output_list[-2]
lambda_t, lambda_s0, lambda_s1 = state.lambda_t[t], state.lambda_t[s0], state.lambda_t[s1]
lambda_t, lambda_s0, lambda_s1 = (
state.lambda_t[t],
state.lambda_t[s0],
state.lambda_t[s1],
)
alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0]
sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0]
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
@@ -443,7 +461,12 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
Returns:
`jnp.ndarray`: the sample tensor at the previous timestep.
"""
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
t, s0, s1, s2 = (
prev_timestep,
timestep_list[-1],
timestep_list[-2],
timestep_list[-3],
)
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
state.lambda_t[t],
@@ -615,7 +638,10 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
return FlaxDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, state=state)
def scale_model_input(
self, state: DPMSolverMultistepSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
self,
state: DPMSolverMultistepSchedulerState,
sample: jnp.ndarray,
timestep: Optional[int] = None,
) -> jnp.ndarray:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the

View File

@@ -19,6 +19,7 @@ import flax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
@@ -28,6 +29,9 @@ from .scheduling_utils_flax import (
)
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class EulerDiscreteSchedulerState:
common: CommonSchedulerState
@@ -40,9 +44,18 @@ class EulerDiscreteSchedulerState:
@classmethod
def create(
cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
cls,
common: CommonSchedulerState,
init_noise_sigma: jnp.ndarray,
timesteps: jnp.ndarray,
sigmas: jnp.ndarray,
):
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)
return cls(
common=common,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
sigmas=sigmas,
)
@dataclass
@@ -99,6 +112,10 @@ class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
timestep_spacing: str = "linspace",
dtype: jnp.dtype = jnp.float32,
):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.dtype = dtype
def create_state(self, common: Optional[CommonSchedulerState] = None) -> EulerDiscreteSchedulerState:
@@ -146,7 +163,10 @@ class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
return sample
def set_timesteps(
self, state: EulerDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
self,
state: EulerDiscreteSchedulerState,
num_inference_steps: int,
shape: Tuple = (),
) -> EulerDiscreteSchedulerState:
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -159,7 +179,12 @@ class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
if self.config.timestep_spacing == "linspace":
timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype)
timesteps = jnp.linspace(
self.config.num_train_timesteps - 1,
0,
num_inference_steps,
dtype=self.dtype,
)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // num_inference_steps
timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)

View File

@@ -22,10 +22,13 @@ import jax.numpy as jnp
from jax import random
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils import BaseOutput, logging
from .scheduling_utils_flax import FlaxSchedulerMixin
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class KarrasVeSchedulerState:
# setable values
@@ -102,7 +105,10 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
s_min: float = 0.05,
s_max: float = 50,
):
pass
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
def create_state(self):
return KarrasVeSchedulerState.create()

View File

@@ -20,6 +20,7 @@ import jax.numpy as jnp
from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
@@ -29,6 +30,9 @@ from .scheduling_utils_flax import (
)
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class LMSDiscreteSchedulerState:
common: CommonSchedulerState
@@ -44,9 +48,18 @@ class LMSDiscreteSchedulerState:
@classmethod
def create(
cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
cls,
common: CommonSchedulerState,
init_noise_sigma: jnp.ndarray,
timesteps: jnp.ndarray,
sigmas: jnp.ndarray,
):
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)
return cls(
common=common,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
sigmas=sigmas,
)
@dataclass
@@ -101,6 +114,10 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.dtype = dtype
def create_state(self, common: Optional[CommonSchedulerState] = None) -> LMSDiscreteSchedulerState:
@@ -165,7 +182,10 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
return integrated_coeff
def set_timesteps(
self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
self,
state: LMSDiscreteSchedulerState,
num_inference_steps: int,
shape: Tuple = (),
) -> LMSDiscreteSchedulerState:
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -177,7 +197,12 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype)
timesteps = jnp.linspace(
self.config.num_train_timesteps - 1,
0,
num_inference_steps,
dtype=self.dtype,
)
low_idx = jnp.floor(timesteps).astype(jnp.int32)
high_idx = jnp.ceil(timesteps).astype(jnp.int32)

View File

@@ -22,6 +22,7 @@ import jax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
@@ -31,6 +32,9 @@ from .scheduling_utils_flax import (
)
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class PNDMSchedulerState:
common: CommonSchedulerState
@@ -131,6 +135,10 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.dtype = dtype
# For now we only support F-PNDM, i.e. the runge-kutta method
@@ -190,7 +198,10 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
else:
prk_timesteps = _timesteps[-self.pndm_order :].repeat(2) + jnp.tile(
jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2], dtype=jnp.int32),
jnp.array(
[0, self.config.num_train_timesteps // num_inference_steps // 2],
dtype=jnp.int32,
),
self.pndm_order,
)
@@ -218,7 +229,10 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
)
def scale_model_input(
self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
self,
state: PNDMSchedulerState,
sample: jnp.ndarray,
timestep: Optional[int] = None,
) -> jnp.ndarray:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
@@ -320,7 +334,9 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
)
diff_to_prev = jnp.where(
state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2
state.counter % 2,
0,
self.config.num_train_timesteps // state.num_inference_steps // 2,
)
prev_timestep = timestep - diff_to_prev
timestep = state.prk_timesteps[state.counter // 4 * 4]
@@ -401,7 +417,9 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep)
timestep = jnp.where(
state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep
state.counter == 1,
timestep + self.config.num_train_timesteps // state.num_inference_steps,
timestep,
)
# Reference:
@@ -466,7 +484,9 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
# prev_sample -> x_(tδ)
alpha_prod_t = state.common.alphas_cumprod[timestep]
alpha_prod_t_prev = jnp.where(
prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
prev_timestep >= 0,
state.common.alphas_cumprod[prev_timestep],
state.final_alpha_cumprod,
)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

View File

@@ -23,7 +23,15 @@ import jax.numpy as jnp
from jax import random
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
from ..utils import logging
from .scheduling_utils_flax import (
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
)
logger = logging.get_logger(__name__)
@flax.struct.dataclass
@@ -95,7 +103,10 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
sampling_eps: float = 1e-5,
correct_steps: int = 1,
):
pass
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
def create_state(self):
state = ScoreSdeVeSchedulerState.create()
@@ -108,7 +119,11 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
)
def set_timesteps(
self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple = (), sampling_eps: float = None
self,
state: ScoreSdeVeSchedulerState,
num_inference_steps: int,
shape: Tuple = (),
sampling_eps: float = None,
) -> ScoreSdeVeSchedulerState:
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.