1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Flax] Stateless schedulers, fixes and refactors (#1661)

* [Flax] Stateless schedulers, fixes and refactors

* Remove scheduling_common_flax and some renames

* Update src/diffusers/schedulers/scheduling_pndm_flax.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
Simon Kirsten
2022-12-20 00:42:41 +00:00
committed by GitHub
parent d87cc15977
commit f106ab40b3
12 changed files with 634 additions and 552 deletions

View File

@@ -475,6 +475,7 @@ def main():
noise_scheduler = FlaxDDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)
noise_scheduler_state = noise_scheduler.create_state()
# Initialize our training
train_rngs = jax.random.split(rng, jax.local_device_count())
@@ -511,7 +512,7 @@ def main():
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
# Get the text embedding for conditioning
if args.train_text_encoder:

View File

@@ -417,6 +417,7 @@ def main():
noise_scheduler = FlaxDDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)
noise_scheduler_state = noise_scheduler.create_state()
# Initialize our training
rng = jax.random.PRNGKey(args.seed)
@@ -449,7 +450,7 @@ def main():
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(

View File

@@ -505,6 +505,7 @@ def main():
noise_scheduler = FlaxDDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)
noise_scheduler_state = noise_scheduler.create_state()
# Initialize our training
train_rngs = jax.random.split(rng, jax.local_device_count())
@@ -531,7 +532,7 @@ def main():
0,
noise_scheduler.config.num_train_timesteps,
)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
encoder_hidden_states = state.apply_fn(
batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True
)[0]

View File

@@ -261,7 +261,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
latents = latents * params["scheduler"].init_noise_sigma
if DEBUG:
# run with python for loop
for i in range(num_inference_steps):

View File

@@ -15,7 +15,6 @@
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
@@ -26,51 +25,37 @@ from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CommonSchedulerState,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
add_noise_common,
)
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
Returns:
betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return jnp.array(betas, dtype=jnp.float32)
@flax.struct.dataclass
class DDIMSchedulerState:
common: CommonSchedulerState
final_alpha_cumprod: jnp.ndarray
# setable values
init_noise_sigma: jnp.ndarray
timesteps: jnp.ndarray
alphas_cumprod: jnp.ndarray
num_inference_steps: Optional[int] = None
@classmethod
def create(cls, num_train_timesteps: int, alphas_cumprod: jnp.ndarray):
return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], alphas_cumprod=alphas_cumprod)
def create(
cls,
common: CommonSchedulerState,
final_alpha_cumprod: jnp.ndarray,
init_noise_sigma: jnp.ndarray,
timesteps: jnp.ndarray,
):
return cls(
common=common,
final_alpha_cumprod=final_alpha_cumprod,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
)
@dataclass
@@ -112,12 +97,15 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
`v-prediction` is not supported for this scheduler.
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
the `dtype` used for params and computation.
"""
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_deprecated_kwargs = ["predict_epsilon"]
dtype: jnp.dtype
@property
def has_state(self):
return True
@@ -129,43 +117,46 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[jnp.ndarray] = None,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
**kwargs,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.dtype = dtype
self.alphas = 1.0 - self.betas
# HACK for now - clean up later (PVP)
self._alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
if common is None:
common = CommonSchedulerState.create(self)
# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else float(self._alphas_cumprod[0])
final_alpha_cumprod = (
jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0]
)
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
init_noise_sigma = jnp.array(1.0, dtype=self.dtype)
timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
return DDIMSchedulerState.create(
common=common,
final_alpha_cumprod=final_alpha_cumprod,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
)
def scale_model_input(
self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
@@ -181,21 +172,6 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
return sample
def create_state(self):
return DDIMSchedulerState.create(
num_train_timesteps=self.config.num_train_timesteps, alphas_cumprod=self._alphas_cumprod
)
def _get_variance(self, timestep, prev_timestep, alphas_cumprod):
alpha_prod_t = alphas_cumprod[timestep]
alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
return variance
def set_timesteps(
self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = ()
) -> DDIMSchedulerState:
@@ -208,15 +184,27 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
offset = self.config.steps_offset
step_ratio = self.config.num_train_timesteps // num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]
timesteps = timesteps + offset
# rounding to avoid issues when num_inference_step is power of 3
timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] + self.config.steps_offset
return state.replace(num_inference_steps=num_inference_steps, timesteps=timesteps)
return state.replace(
num_inference_steps=num_inference_steps,
timesteps=timesteps,
)
def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep):
alpha_prod_t = state.common.alphas_cumprod[timestep]
alpha_prod_t_prev = jnp.where(
prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
return variance
def step(
self,
@@ -224,6 +212,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
eta: float = 0.0,
return_dict: bool = True,
) -> Union[FlaxDDIMSchedulerOutput, Tuple]:
"""
@@ -259,17 +248,15 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
# - pred_sample_direction -> "direction pointing to x_t"
# - pred_prev_sample -> "x_t-1"
# TODO(Patrick) - eta is always 0.0 for now, allow to be set in step function
eta = 0.0
# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
alphas_cumprod = state.alphas_cumprod
alphas_cumprod = state.common.alphas_cumprod
final_alpha_cumprod = state.final_alpha_cumprod
# 2. compute alphas, betas
alpha_prod_t = alphas_cumprod[timestep]
alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], final_alpha_cumprod)
beta_prod_t = 1 - alpha_prod_t
@@ -291,7 +278,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
# 4. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 α_t1)/(1 α_t)) * sqrt(1 α_t/α_t1)
variance = self._get_variance(timestep, prev_timestep, alphas_cumprod)
variance = self._get_variance(state, timestep, prev_timestep)
std_dev_t = eta * variance ** (0.5)
# 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
@@ -307,20 +294,12 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
def add_noise(
self,
state: DDIMSchedulerState,
original_samples: jnp.ndarray,
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.0
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
return add_noise_common(state.common, original_samples, noise, timesteps)
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -14,62 +14,36 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import flax
import jax
import jax.numpy as jnp
from jax import random
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CommonSchedulerState,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
add_noise_common,
)
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
Returns:
betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return jnp.array(betas, dtype=jnp.float32)
@flax.struct.dataclass
class DDPMSchedulerState:
common: CommonSchedulerState
# setable values
init_noise_sigma: jnp.ndarray
timesteps: jnp.ndarray
num_inference_steps: Optional[int] = None
@classmethod
def create(cls, num_train_timesteps: int):
return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1])
def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray):
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps)
@dataclass
@@ -106,11 +80,15 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
`v-prediction` is not supported for this scheduler.
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
the `dtype` used for params and computation.
"""
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_deprecated_kwargs = ["predict_epsilon"]
dtype: jnp.dtype
@property
def has_state(self):
return True
@@ -126,35 +104,47 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
variance_type: str = "fixed_small",
clip_sample: bool = True,
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
**kwargs,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
elif beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.dtype = dtype
self.alphas = 1.0 - self.betas
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
self.one = jnp.array(1.0)
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState:
if common is None:
common = CommonSchedulerState.create(self)
def create_state(self):
return DDPMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
# standard deviation of the initial noise distribution
init_noise_sigma = jnp.array(1.0, dtype=self.dtype)
timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
return DDPMSchedulerState.create(
common=common,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
)
def scale_model_input(
self, state: DDPMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
) -> jnp.ndarray:
"""
Args:
state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
sample (`jnp.ndarray`): input sample
timestep (`int`, optional): current timestep
Returns:
`jnp.ndarray`: scaled input sample
"""
return sample
def set_timesteps(
self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple = ()
@@ -168,20 +158,25 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
timesteps = jnp.arange(
0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps
)[::-1]
return state.replace(num_inference_steps=num_inference_steps, timesteps=timesteps)
def _get_variance(self, t, predicted_variance=None, variance_type=None):
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
step_ratio = self.config.num_train_timesteps // num_inference_steps
# creates integer timesteps by multiplying by ratio
# rounding to avoid issues when num_inference_step is power of 3
timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]
return state.replace(
num_inference_steps=num_inference_steps,
timesteps=timesteps,
)
def _get_variance(self, state: DDPMSchedulerState, t, predicted_variance=None, variance_type=None):
alpha_prod_t = state.common.alphas_cumprod[t]
alpha_prod_t_prev = jnp.where(t > 0, state.common.alphas_cumprod[t - 1], jnp.array(1.0, dtype=self.dtype))
# For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
# and sample from it to get previous sample
# x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * state.common.betas[t]
if variance_type is None:
variance_type = self.config.variance_type
@@ -193,15 +188,15 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
elif variance_type == "fixed_small_log":
variance = jnp.log(jnp.clip(variance, a_min=1e-20))
elif variance_type == "fixed_large":
variance = self.betas[t]
variance = state.common.betas[t]
elif variance_type == "fixed_large_log":
# Glide max_log
variance = jnp.log(self.betas[t])
variance = jnp.log(state.common.betas[t])
elif variance_type == "learned":
return predicted_variance
elif variance_type == "learned_range":
min_log = variance
max_log = self.betas[t]
max_log = state.common.betas[t]
frac = (predicted_variance + 1) / 2
variance = frac * max_log + (1 - frac) * min_log
@@ -213,9 +208,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
key: random.KeyArray,
key: jax.random.KeyArray = jax.random.PRNGKey(0),
return_dict: bool = True,
**kwargs,
) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
@@ -227,7 +221,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
key (`random.KeyArray`): a PRNG key.
key (`jax.random.KeyArray`): a PRNG key.
return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class
Returns:
@@ -235,16 +229,6 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
`tuple`. When returning a tuple, the first element is the sample tensor.
"""
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
if predict_epsilon is not None:
new_config = dict(self.config)
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
self._internal_dict = FrozenDict(new_config)
t = timestep
if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
@@ -253,8 +237,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
predicted_variance = None
# 1. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
alpha_prod_t = state.common.alphas_cumprod[t]
alpha_prod_t_prev = jnp.where(t > 0, state.common.alphas_cumprod[t - 1], jnp.array(1.0, dtype=self.dtype))
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
@@ -264,6 +248,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
@@ -276,19 +262,20 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * state.common.betas[t]) / beta_prod_t
current_sample_coeff = state.common.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
# 5. Compute predicted previous sample µ_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
# 6. Add noise
variance = 0
if t > 0:
key = random.split(key, num=1)
noise = random.normal(key=key, shape=model_output.shape)
variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
def random_variance():
split_key = jax.random.split(key, num=1)
noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype)
return (self._get_variance(state, t, predicted_variance=predicted_variance) ** 0.5) * noise
variance = jnp.where(t > 0, random_variance(), jnp.zeros(model_output.shape, dtype=self.dtype))
pred_prev_sample = pred_prev_sample + variance
@@ -299,20 +286,12 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
def add_noise(
self,
state: DDPMSchedulerState,
original_samples: jnp.ndarray,
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
return add_noise_common(state.common, original_samples, noise, timesteps)
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -14,7 +14,6 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
@@ -26,57 +25,49 @@ from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CommonSchedulerState,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
add_noise_common,
)
def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
Returns:
betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return jnp.array(betas, dtype=jnp.float32)
@flax.struct.dataclass
class DPMSolverMultistepSchedulerState:
common: CommonSchedulerState
alpha_t: jnp.ndarray
sigma_t: jnp.ndarray
lambda_t: jnp.ndarray
# setable values
init_noise_sigma: jnp.ndarray
timesteps: jnp.ndarray
num_inference_steps: Optional[int] = None
timesteps: Optional[jnp.ndarray] = None
# running values
model_outputs: Optional[jnp.ndarray] = None
lower_order_nums: Optional[int] = None
step_index: Optional[int] = None
prev_timestep: Optional[int] = None
lower_order_nums: Optional[jnp.int32] = None
prev_timestep: Optional[jnp.int32] = None
cur_sample: Optional[jnp.ndarray] = None
@classmethod
def create(cls, num_train_timesteps: int):
return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1])
def create(
cls,
common: CommonSchedulerState,
alpha_t: jnp.ndarray,
sigma_t: jnp.ndarray,
lambda_t: jnp.ndarray,
init_noise_sigma: jnp.ndarray,
timesteps: jnp.ndarray,
):
return cls(
common=common,
alpha_t=alpha_t,
sigma_t=sigma_t,
lambda_t=lambda_t,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
)
@dataclass
@@ -145,12 +136,15 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
lower_order_final (`bool`, default `True`):
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
the `dtype` used for params and computation.
"""
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_deprecated_kwargs = ["predict_epsilon"]
dtype: jnp.dtype
@property
def has_state(self):
return True
@@ -171,47 +165,47 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
dtype: jnp.dtype = jnp.float32,
**kwargs,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
elif beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.dtype = dtype
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState:
if common is None:
common = CommonSchedulerState.create(self)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
# Currently we only support VP-type noise schedule
self.alpha_t = jnp.sqrt(self.alphas_cumprod)
self.sigma_t = jnp.sqrt(1 - self.alphas_cumprod)
self.lambda_t = jnp.log(self.alpha_t) - jnp.log(self.sigma_t)
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
alpha_t = jnp.sqrt(common.alphas_cumprod)
sigma_t = jnp.sqrt(1 - common.alphas_cumprod)
lambda_t = jnp.log(alpha_t) - jnp.log(sigma_t)
# settings for DPM-Solver
if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
if solver_type not in ["midpoint", "heun"]:
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
if self.config.algorithm_type not in ["dpmsolver", "dpmsolver++"]:
raise NotImplementedError(f"{self.config.algorithm_type} does is not implemented for {self.__class__}")
if self.config.solver_type not in ["midpoint", "heun"]:
raise NotImplementedError(f"{self.config.solver_type} does is not implemented for {self.__class__}")
def create_state(self):
return DPMSolverMultistepSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
# standard deviation of the initial noise distribution
init_noise_sigma = jnp.array(1.0, dtype=self.dtype)
timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
return DPMSolverMultistepSchedulerState.create(
common=common,
alpha_t=alpha_t,
sigma_t=sigma_t,
lambda_t=lambda_t,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
)
def set_timesteps(
self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple
@@ -227,24 +221,32 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
shape (`Tuple`):
the shape of the samples to be generated.
"""
timesteps = (
jnp.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
.round()[::-1][:-1]
.astype(jnp.int32)
)
# initial running values
model_outputs = jnp.zeros((self.config.solver_order,) + shape, dtype=self.dtype)
lower_order_nums = jnp.int32(0)
prev_timestep = jnp.int32(-1)
cur_sample = jnp.zeros(shape, dtype=self.dtype)
return state.replace(
num_inference_steps=num_inference_steps,
timesteps=timesteps,
model_outputs=jnp.zeros((self.config.solver_order,) + shape),
lower_order_nums=0,
step_index=0,
prev_timestep=-1,
cur_sample=jnp.zeros(shape),
model_outputs=model_outputs,
lower_order_nums=lower_order_nums,
prev_timestep=prev_timestep,
cur_sample=cur_sample,
)
def convert_model_output(
self,
state: DPMSolverMultistepSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
@@ -271,12 +273,12 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
# DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type == "dpmsolver++":
if self.config.prediction_type == "epsilon":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample":
x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
x0_pred = alpha_t * sample - sigma_t * model_output
else:
raise ValueError(
@@ -299,11 +301,11 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
if self.config.prediction_type == "epsilon":
return model_output
elif self.config.prediction_type == "sample":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
epsilon = alpha_t * model_output + sigma_t * sample
return epsilon
else:
@@ -313,7 +315,12 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
)
def dpm_solver_first_order_update(
self, model_output: jnp.ndarray, timestep: int, prev_timestep: int, sample: jnp.ndarray
self,
state: DPMSolverMultistepSchedulerState,
model_output: jnp.ndarray,
timestep: int,
prev_timestep: int,
sample: jnp.ndarray,
) -> jnp.ndarray:
"""
One step for the first-order DPM-Solver (equivalent to DDIM).
@@ -332,9 +339,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
t, s0 = prev_timestep, timestep
m0 = model_output
lambda_t, lambda_s = self.lambda_t[t], self.lambda_t[s0]
alpha_t, alpha_s = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s = self.sigma_t[t], self.sigma_t[s0]
lambda_t, lambda_s = state.lambda_t[t], state.lambda_t[s0]
alpha_t, alpha_s = state.alpha_t[t], state.alpha_t[s0]
sigma_t, sigma_s = state.sigma_t[t], state.sigma_t[s0]
h = lambda_t - lambda_s
if self.config.algorithm_type == "dpmsolver++":
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * m0
@@ -344,6 +351,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
def multistep_dpm_solver_second_order_update(
self,
state: DPMSolverMultistepSchedulerState,
model_output_list: jnp.ndarray,
timestep_list: List[int],
prev_timestep: int,
@@ -365,9 +373,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
m0, m1 = model_output_list[-1], model_output_list[-2]
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
lambda_t, lambda_s0, lambda_s1 = state.lambda_t[t], state.lambda_t[s0], state.lambda_t[s1]
alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0]
sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0]
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
r0 = h_0 / h
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
@@ -403,6 +411,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
def multistep_dpm_solver_third_order_update(
self,
state: DPMSolverMultistepSchedulerState,
model_output_list: jnp.ndarray,
timestep_list: List[int],
prev_timestep: int,
@@ -425,13 +434,13 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
self.lambda_t[t],
self.lambda_t[s0],
self.lambda_t[s1],
self.lambda_t[s2],
state.lambda_t[t],
state.lambda_t[s0],
state.lambda_t[s1],
state.lambda_t[s2],
)
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0]
sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0]
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
r0, r1 = h_0 / h, h_1 / h
D0 = m0
@@ -482,14 +491,17 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
prev_timestep = jax.lax.cond(
state.step_index == len(state.timesteps) - 1,
lambda _: 0,
lambda _: state.timesteps[state.step_index + 1],
(),
)
if state.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
model_output = self.convert_model_output(model_output, timestep, sample)
(step_index,) = jnp.where(state.timesteps == timestep, size=1)
step_index = step_index[0]
prev_timestep = jax.lax.select(step_index == len(state.timesteps) - 1, 0, state.timesteps[step_index + 1])
model_output = self.convert_model_output(state, model_output, timestep, sample)
model_outputs_new = jnp.roll(state.model_outputs, -1, axis=0)
model_outputs_new = model_outputs_new.at[-1].set(model_output)
@@ -501,16 +513,18 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
def step_1(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
return self.dpm_solver_first_order_update(
state,
state.model_outputs[-1],
state.timesteps[state.step_index],
state.timesteps[step_index],
state.prev_timestep,
state.cur_sample,
)
def step_23(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
def step_2(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
timestep_list = jnp.array([state.timesteps[state.step_index - 1], state.timesteps[state.step_index]])
timestep_list = jnp.array([state.timesteps[step_index - 1], state.timesteps[step_index]])
return self.multistep_dpm_solver_second_order_update(
state,
state.model_outputs,
timestep_list,
state.prev_timestep,
@@ -520,65 +534,67 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
timestep_list = jnp.array(
[
state.timesteps[state.step_index - 2],
state.timesteps[state.step_index - 1],
state.timesteps[state.step_index],
state.timesteps[step_index - 2],
state.timesteps[step_index - 1],
state.timesteps[step_index],
]
)
return self.multistep_dpm_solver_third_order_update(
state,
state.model_outputs,
timestep_list,
state.prev_timestep,
state.cur_sample,
)
step_2_output = step_2(state)
step_3_output = step_3(state)
if self.config.solver_order == 2:
return step_2(state)
return step_2_output
elif self.config.lower_order_final and len(state.timesteps) < 15:
return jax.lax.cond(
return jax.lax.select(
state.lower_order_nums < 2,
step_2,
lambda state: jax.lax.cond(
state.step_index == len(state.timesteps) - 2,
step_2,
step_3,
state,
step_2_output,
jax.lax.select(
step_index == len(state.timesteps) - 2,
step_2_output,
step_3_output,
),
state,
)
else:
return jax.lax.cond(
return jax.lax.select(
state.lower_order_nums < 2,
step_2,
step_3,
state,
step_2_output,
step_3_output,
)
step_1_output = step_1(state)
step_23_output = step_23(state)
if self.config.solver_order == 1:
prev_sample = step_1(state)
prev_sample = step_1_output
elif self.config.lower_order_final and len(state.timesteps) < 15:
prev_sample = jax.lax.cond(
prev_sample = jax.lax.select(
state.lower_order_nums < 1,
step_1,
lambda state: jax.lax.cond(
state.step_index == len(state.timesteps) - 1,
step_1,
step_23,
state,
step_1_output,
jax.lax.select(
step_index == len(state.timesteps) - 1,
step_1_output,
step_23_output,
),
state,
)
else:
prev_sample = jax.lax.cond(
prev_sample = jax.lax.select(
state.lower_order_nums < 1,
step_1,
step_23,
state,
step_1_output,
step_23_output,
)
state = state.replace(
lower_order_nums=jnp.minimum(state.lower_order_nums + 1, self.config.solver_order),
step_index=(state.step_index + 1),
)
if not return_dict:
@@ -606,20 +622,12 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
def add_noise(
self,
state: DPMSolverMultistepSchedulerState,
original_samples: jnp.ndarray,
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.0
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
return add_noise_common(state.common, original_samples, noise, timesteps)
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -233,5 +233,5 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state)
def add_noise(self, original_samples, noise, timesteps):
def add_noise(self, state: KarrasVeSchedulerState, original_samples, noise, timesteps):
raise NotImplementedError()

View File

@@ -22,6 +22,7 @@ from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CommonSchedulerState,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
@@ -30,15 +31,22 @@ from .scheduling_utils_flax import (
@flax.struct.dataclass
class LMSDiscreteSchedulerState:
common: CommonSchedulerState
# setable values
init_noise_sigma: jnp.ndarray
timesteps: jnp.ndarray
sigmas: jnp.ndarray
num_inference_steps: Optional[int] = None
timesteps: Optional[jnp.ndarray] = None
sigmas: Optional[jnp.ndarray] = None
derivatives: jnp.ndarray = jnp.array([])
# running values
derivatives: Optional[jnp.ndarray] = None
@classmethod
def create(cls, num_train_timesteps: int, sigmas: jnp.ndarray):
return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], sigmas=sigmas)
def create(
cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
):
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)
@dataclass
@@ -66,10 +74,18 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
`linear` or `scaled_linear`.
trained_betas (`jnp.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
the `dtype` used for params and computation.
"""
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
dtype: jnp.dtype
@property
def has_state(self):
return True
@@ -82,24 +98,26 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[jnp.ndarray] = None,
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
):
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
elif beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.dtype = dtype
self.alphas = 1.0 - self.betas
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
def create_state(self, common: Optional[CommonSchedulerState] = None) -> LMSDiscreteSchedulerState:
if common is None:
common = CommonSchedulerState.create(self)
def create_state(self):
self.state = LMSDiscreteSchedulerState.create(
num_train_timesteps=self.config.num_train_timesteps,
sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5,
timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5
# standard deviation of the initial noise distribution
init_noise_sigma = sigmas.max()
return LMSDiscreteSchedulerState.create(
common=common,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
sigmas=sigmas,
)
def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray:
@@ -118,11 +136,13 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
`jnp.ndarray`: scaled input sample
"""
(step_index,) = jnp.where(state.timesteps == timestep, size=1)
step_index = step_index[0]
sigma = state.sigmas[step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
def get_lms_coefficient(self, state, order, t, current_order):
def get_lms_coefficient(self, state: LMSDiscreteSchedulerState, order, t, current_order):
"""
Compute a linear multistep coefficient.
@@ -156,20 +176,28 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=jnp.float32)
low_idx = jnp.floor(timesteps).astype(int)
high_idx = jnp.ceil(timesteps).astype(int)
timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype)
low_idx = jnp.floor(timesteps).astype(jnp.int32)
high_idx = jnp.ceil(timesteps).astype(jnp.int32)
frac = jnp.mod(timesteps, 1.0)
sigmas = jnp.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = ((1 - state.common.alphas_cumprod) / state.common.alphas_cumprod) ** 0.5
sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
sigmas = jnp.concatenate([sigmas, jnp.array([0.0])]).astype(jnp.float32)
sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)])
timesteps = timesteps.astype(jnp.int32)
# initial running values
derivatives = jnp.zeros((0,) + shape, dtype=self.dtype)
return state.replace(
num_inference_steps=num_inference_steps,
timesteps=timesteps.astype(int),
derivatives=jnp.array([]),
timesteps=timesteps,
sigmas=sigmas,
num_inference_steps=num_inference_steps,
derivatives=derivatives,
)
def step(
@@ -199,10 +227,23 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
`tuple`. When returning a tuple, the first element is the sample tensor.
"""
if state.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
sigma = state.sigmas[timestep]
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
pred_original_sample = sample - sigma * model_output
if self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma * model_output
elif self.config.prediction_type == "v_prediction":
# * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
)
# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma

View File

@@ -14,7 +14,6 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
@@ -25,59 +24,45 @@ import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CommonSchedulerState,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
add_noise_common,
)
def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
Returns:
betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return jnp.array(betas, dtype=jnp.float32)
@flax.struct.dataclass
class PNDMSchedulerState:
common: CommonSchedulerState
final_alpha_cumprod: jnp.ndarray
# setable values
_timesteps: jnp.ndarray
init_noise_sigma: jnp.ndarray
timesteps: jnp.ndarray
num_inference_steps: Optional[int] = None
prk_timesteps: Optional[jnp.ndarray] = None
plms_timesteps: Optional[jnp.ndarray] = None
timesteps: Optional[jnp.ndarray] = None
# running values
cur_model_output: Optional[jnp.ndarray] = None
counter: int = 0
counter: Optional[jnp.int32] = None
cur_sample: Optional[jnp.ndarray] = None
ets: jnp.ndarray = jnp.array([])
ets: Optional[jnp.ndarray] = None
@classmethod
def create(cls, num_train_timesteps: int):
return cls(_timesteps=jnp.arange(0, num_train_timesteps)[::-1])
def create(
cls,
common: CommonSchedulerState,
final_alpha_cumprod: jnp.ndarray,
init_noise_sigma: jnp.ndarray,
timesteps: jnp.ndarray,
):
return cls(
common=common,
final_alpha_cumprod=final_alpha_cumprod,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
)
@dataclass
@@ -117,10 +102,19 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
the `dtype` used for params and computation.
"""
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
dtype: jnp.dtype
pndm_order: int
@property
def has_state(self):
return True
@@ -136,35 +130,39 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
skip_prk_steps: bool = False,
set_alpha_to_one: bool = False,
steps_offset: int = 0,
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
):
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
elif beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
self.dtype = dtype
# For now we only support F-PNDM, i.e. the runge-kutta method
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
# mainly at formula (9), (12), (13) and the Algorithm 2.
self.pndm_order = 4
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
def create_state(self, common: Optional[CommonSchedulerState] = None) -> PNDMSchedulerState:
if common is None:
common = CommonSchedulerState.create(self)
def create_state(self):
return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
final_alpha_cumprod = (
jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0]
)
# standard deviation of the initial noise distribution
init_noise_sigma = jnp.array(1.0, dtype=self.dtype)
timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
return PNDMSchedulerState.create(
common=common,
final_alpha_cumprod=final_alpha_cumprod,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
)
def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, shape: Tuple) -> PNDMSchedulerState:
"""
@@ -178,42 +176,47 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
shape (`Tuple`):
the shape of the samples to be generated.
"""
offset = self.config.steps_offset
step_ratio = self.config.num_train_timesteps // num_inference_steps
# creates integer timesteps by multiplying by ratio
# rounding to avoid issues when num_inference_step is power of 3
_timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + offset
state = state.replace(num_inference_steps=num_inference_steps, _timesteps=_timesteps)
_timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + self.config.steps_offset
if self.config.skip_prk_steps:
# for some models like stable diffusion the prk steps can/should be skipped to
# produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
# is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
state = state.replace(
prk_timesteps=jnp.array([]),
plms_timesteps=jnp.concatenate(
[state._timesteps[:-1], state._timesteps[-2:-1], state._timesteps[-1:]]
)[::-1],
)
prk_timesteps = jnp.array([], dtype=jnp.int32)
plms_timesteps = jnp.concatenate([_timesteps[:-1], _timesteps[-2:-1], _timesteps[-1:]])[::-1]
else:
prk_timesteps = jnp.array(state._timesteps[-self.pndm_order :]).repeat(2) + jnp.tile(
jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
prk_timesteps = _timesteps[-self.pndm_order :].repeat(2) + jnp.tile(
jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2], dtype=jnp.int32),
self.pndm_order,
)
state = state.replace(
prk_timesteps=(prk_timesteps[:-1].repeat(2)[1:-1])[::-1],
plms_timesteps=state._timesteps[:-3][::-1],
)
prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1]
plms_timesteps = _timesteps[:-3][::-1]
timesteps = jnp.concatenate([prk_timesteps, plms_timesteps])
# initial running values
cur_model_output = jnp.zeros(shape, dtype=self.dtype)
counter = jnp.int32(0)
cur_sample = jnp.zeros(shape, dtype=self.dtype)
ets = jnp.zeros((4,) + shape, dtype=self.dtype)
return state.replace(
timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int32),
counter=0,
# Reserve space for the state variables
cur_model_output=jnp.zeros(shape),
cur_sample=jnp.zeros(shape),
ets=jnp.zeros((4,) + shape),
timesteps=timesteps,
num_inference_steps=num_inference_steps,
prk_timesteps=prk_timesteps,
plms_timesteps=plms_timesteps,
cur_model_output=cur_model_output,
counter=counter,
cur_sample=cur_sample,
ets=ets,
)
def scale_model_input(
@@ -260,19 +263,27 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
`tuple`. When returning a tuple, the first element is the sample tensor.
"""
if self.config.skip_prk_steps:
prev_sample, state = self.step_plms(
state=state, model_output=model_output, timestep=timestep, sample=sample
if state.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if self.config.skip_prk_steps:
prev_sample, state = self.step_plms(state, model_output, timestep, sample)
else:
prev_sample, state = jax.lax.switch(
jnp.where(state.counter < len(state.prk_timesteps), 0, 1),
(self.step_prk, self.step_plms),
# Args to either branch
state,
model_output,
timestep,
sample,
prk_prev_sample, prk_state = self.step_prk(state, model_output, timestep, sample)
plms_prev_sample, plms_state = self.step_plms(state, model_output, timestep, sample)
cond = state.counter < len(state.prk_timesteps)
prev_sample = jax.lax.select(cond, prk_prev_sample, plms_prev_sample)
state = state.replace(
cur_model_output=jax.lax.select(cond, prk_state.cur_model_output, plms_state.cur_model_output),
ets=jax.lax.select(cond, prk_state.ets, plms_state.ets),
cur_sample=jax.lax.select(cond, prk_state.cur_sample, plms_state.cur_sample),
counter=jax.lax.select(cond, prk_state.counter, plms_state.counter),
)
if not return_dict:
@@ -304,6 +315,7 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
`tuple`. When returning a tuple, the first element is the sample tensor.
"""
if state.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
@@ -315,37 +327,34 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
prev_timestep = timestep - diff_to_prev
timestep = state.prk_timesteps[state.counter // 4 * 4]
def remainder_0(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
return (
state.replace(
cur_model_output=state.cur_model_output + 1 / 6 * model_output,
ets=state.ets.at[ets_at].set(model_output),
cur_sample=sample,
),
model_output,
)
model_output = jax.lax.select(
(state.counter % 4) != 3,
model_output, # remainder 0, 1, 2
state.cur_model_output + 1 / 6 * model_output, # remainder 3
)
def remainder_1(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output
def remainder_2(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output
def remainder_3(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
model_output = state.cur_model_output + 1 / 6 * model_output
return state.replace(cur_model_output=jnp.zeros_like(state.cur_model_output)), model_output
state, model_output = jax.lax.switch(
state.counter % 4,
(remainder_0, remainder_1, remainder_2, remainder_3),
# Args to either branch
state,
model_output,
state.counter // 4,
state = state.replace(
cur_model_output=jax.lax.select_n(
state.counter % 4,
state.cur_model_output + 1 / 6 * model_output, # remainder 0
state.cur_model_output + 1 / 3 * model_output, # remainder 1
state.cur_model_output + 1 / 3 * model_output, # remainder 2
jnp.zeros_like(state.cur_model_output), # remainder 3
),
ets=jax.lax.select(
(state.counter % 4) == 0,
state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output), # remainder 0
state.ets, # remainder 1, 2, 3
),
cur_sample=jax.lax.select(
(state.counter % 4) == 0,
sample, # remainder 0
state.cur_sample, # remainder 1, 2, 3
),
)
cur_sample = state.cur_sample
prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
prev_sample = self._get_prev_sample(state, cur_sample, timestep, prev_timestep, model_output)
state = state.replace(counter=state.counter + 1)
return (prev_sample, state)
@@ -374,18 +383,13 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
`tuple`. When returning a tuple, the first element is the sample tensor.
"""
if state.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if not self.config.skip_prk_steps and len(state.ets) < 3:
raise ValueError(
f"{self.__class__} can only be run AFTER scheduler has been run "
"in 'prk' mode for at least 12 iterations "
"See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
"for more information."
)
# NOTE: There is no way to check in the jitted runtime if the prk mode was ran before
prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0)
@@ -417,64 +421,39 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
# else:
# model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4])
def counter_0(state: PNDMSchedulerState):
ets = state.ets.at[0].set(model_output)
return state.replace(
ets=ets,
cur_sample=sample,
cur_model_output=jnp.array(model_output, dtype=jnp.float32),
)
state = state.replace(
ets=jax.lax.select(
state.counter != 1,
state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output), # counter != 1
state.ets, # counter 1
),
cur_sample=jax.lax.select(
state.counter != 1,
sample, # counter != 1
state.cur_sample, # counter 1
),
)
def counter_1(state: PNDMSchedulerState):
return state.replace(
cur_model_output=(model_output + state.ets[0]) / 2,
)
def counter_2(state: PNDMSchedulerState):
ets = state.ets.at[1].set(model_output)
return state.replace(
ets=ets,
cur_model_output=(3 * ets[1] - ets[0]) / 2,
cur_sample=sample,
)
def counter_3(state: PNDMSchedulerState):
ets = state.ets.at[2].set(model_output)
return state.replace(
ets=ets,
cur_model_output=(23 * ets[2] - 16 * ets[1] + 5 * ets[0]) / 12,
cur_sample=sample,
)
def counter_other(state: PNDMSchedulerState):
ets = state.ets.at[3].set(model_output)
next_model_output = (1 / 24) * (55 * ets[3] - 59 * ets[2] + 37 * ets[1] - 9 * ets[0])
ets = ets.at[0].set(ets[1])
ets = ets.at[1].set(ets[2])
ets = ets.at[2].set(ets[3])
return state.replace(
ets=ets,
cur_model_output=next_model_output,
cur_sample=sample,
)
counter = jnp.clip(state.counter, 0, 4)
state = jax.lax.switch(
counter,
[counter_0, counter_1, counter_2, counter_3, counter_other],
state,
state = state.replace(
cur_model_output=jax.lax.select_n(
jnp.clip(state.counter, 0, 4),
model_output, # counter 0
(model_output + state.ets[-1]) / 2, # counter 1
(3 * state.ets[-1] - state.ets[-2]) / 2, # counter 2
(23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12, # counter 3
(1 / 24)
* (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]), # counter >= 4
),
)
sample = state.cur_sample
model_output = state.cur_model_output
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
prev_sample = self._get_prev_sample(state, sample, timestep, prev_timestep, model_output)
state = state.replace(counter=state.counter + 1)
return (prev_sample, state)
def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
def _get_prev_sample(self, state: PNDMSchedulerState, sample, timestep, prev_timestep, model_output):
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
# this function computes x_(tδ) using the formula of (9)
# Note that x_t needs to be added to both sides of the equation
@@ -487,11 +466,20 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
# sample -> x_t
# model_output -> e_θ(x_t, t)
# prev_sample -> x_(tδ)
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = jnp.where(prev_timestep >= 0, self.alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
alpha_prod_t = state.common.alphas_cumprod[timestep]
alpha_prod_t_prev = jnp.where(
prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
if self.config.prediction_type == "v_prediction":
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
elif self.config.prediction_type != "epsilon":
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`"
)
# corresponds to (α_(tδ) - α_t) divided by
# denominator of x_t in formula (9) and plus 1
# Note: (α_(tδ) - α_t) / (sqrt(α_t) * (sqrt(α_(tδ)) + sqr(α_t))) =
@@ -512,20 +500,12 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
def add_noise(
self,
state: PNDMSchedulerState,
original_samples: jnp.ndarray,
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
return add_noise_common(state.common, original_samples, noise, timesteps)
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import math
import os
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import flax
import jax.numpy as jnp
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
@@ -50,6 +52,7 @@ class FlaxSchedulerMixin:
"""
config_name = SCHEDULER_CONFIG_NAME
ignore_for_config = ["dtype"]
_compatibles = []
has_compatibles = True
@@ -167,3 +170,90 @@ class FlaxSchedulerMixin:
def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray:
assert len(shape) >= x.ndim
return jnp.broadcast_to(x.reshape(x.shape + (1,) * (len(shape) - x.ndim)), shape)
def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999, dtype=jnp.float32) -> jnp.ndarray:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
Returns:
betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return jnp.array(betas, dtype=dtype)
@flax.struct.dataclass
class CommonSchedulerState:
alphas: jnp.ndarray
betas: jnp.ndarray
alphas_cumprod: jnp.ndarray
@classmethod
def create(cls, scheduler):
config = scheduler.config
if config.trained_betas is not None:
betas = jnp.asarray(config.trained_betas, dtype=scheduler.dtype)
elif config.beta_schedule == "linear":
betas = jnp.linspace(config.beta_start, config.beta_end, config.num_train_timesteps, dtype=scheduler.dtype)
elif config.beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
betas = (
jnp.linspace(
config.beta_start**0.5, config.beta_end**0.5, config.num_train_timesteps, dtype=scheduler.dtype
)
** 2
)
elif config.beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
betas = betas_for_alpha_bar(config.num_train_timesteps, dtype=scheduler.dtype)
else:
raise NotImplementedError(
f"beta_schedule {config.beta_schedule} is not implemented for scheduler {scheduler.__class__.__name__}"
)
alphas = 1.0 - betas
alphas_cumprod = jnp.cumprod(alphas, axis=0)
return cls(
alphas=alphas,
betas=betas,
alphas_cumprod=alphas_cumprod,
)
def add_noise_common(
state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
):
alphas_cumprod = state.alphas_cumprod
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples

View File

@@ -296,10 +296,11 @@ class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
assert jnp.sum(jnp.abs(scheduler._get_variance(0) - 0.0)) < 1e-5
assert jnp.sum(jnp.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5
assert jnp.sum(jnp.abs(scheduler._get_variance(999) - 0.02)) < 1e-5
assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0) - 0.0)) < 1e-5
assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487) - 0.00979)) < 1e-5
assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999) - 0.02)) < 1e-5
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
@@ -577,12 +578,12 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
assert jnp.sum(jnp.abs(scheduler._get_variance(0, 0, state.alphas_cumprod) - 0.0)) < 1e-5
assert jnp.sum(jnp.abs(scheduler._get_variance(420, 400, state.alphas_cumprod) - 0.14771)) < 1e-5
assert jnp.sum(jnp.abs(scheduler._get_variance(980, 960, state.alphas_cumprod) - 0.32460)) < 1e-5
assert jnp.sum(jnp.abs(scheduler._get_variance(0, 0, state.alphas_cumprod) - 0.0)) < 1e-5
assert jnp.sum(jnp.abs(scheduler._get_variance(487, 486, state.alphas_cumprod) - 0.00979)) < 1e-5
assert jnp.sum(jnp.abs(scheduler._get_variance(999, 998, state.alphas_cumprod) - 0.02)) < 1e-5
assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5
assert jnp.sum(jnp.abs(scheduler._get_variance(state, 420, 400) - 0.14771)) < 1e-5
assert jnp.sum(jnp.abs(scheduler._get_variance(state, 980, 960) - 0.32460)) < 1e-5
assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5
assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487, 486) - 0.00979)) < 1e-5
assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999, 998) - 0.02)) < 1e-5
def test_full_loop_no_noise(self):
sample = self.full_loop()