Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[Flax] Stateless schedulers, fixes and refactors (#1661)
* [Flax] Stateless schedulers, fixes and refactors
* Remove scheduling_common_flax and some renames
* Update src/diffusers/schedulers/scheduling_pndm_flax.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
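The common pattern introduced by this commit: every Flax scheduler exposes `create_state()`, and all mutable values (timesteps, sigmas, running buffers) live in that state pytree, which is passed explicitly to `add_noise`, `set_timesteps` and `step`. Below is a minimal sketch of that usage in a training step; the array shapes and the training loop are placeholders, not the exact code from the example scripts.

    import jax
    import jax.numpy as jnp
    from diffusers import FlaxDDPMScheduler

    noise_scheduler = FlaxDDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
    )
    # Mutable values now live in an explicit, jit-friendly state object.
    noise_scheduler_state = noise_scheduler.create_state()

    rng = jax.random.PRNGKey(0)
    latents = jnp.zeros((4, 4, 64, 64))                      # placeholder latents
    noise = jax.random.normal(rng, latents.shape)
    timesteps = jax.random.randint(
        rng, (latents.shape[0],), 0, noise_scheduler.config.num_train_timesteps
    )

    # The state is passed explicitly instead of being read from `self`.
    noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)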
@@ -475,6 +475,7 @@ def main():
noise_scheduler = FlaxDDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)
noise_scheduler_state = noise_scheduler.create_state()

# Initialize our training
train_rngs = jax.random.split(rng, jax.local_device_count())
@@ -511,7 +512,7 @@ def main():

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)

# Get the text embedding for conditioning
if args.train_text_encoder:

@@ -417,6 +417,7 @@ def main():
noise_scheduler = FlaxDDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)
noise_scheduler_state = noise_scheduler.create_state()

# Initialize our training
rng = jax.random.PRNGKey(args.seed)
@@ -449,7 +450,7 @@ def main():

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)

# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(

@@ -505,6 +505,7 @@ def main():
noise_scheduler = FlaxDDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)
noise_scheduler_state = noise_scheduler.create_state()

# Initialize our training
train_rngs = jax.random.split(rng, jax.local_device_count())
@@ -531,7 +532,7 @@ def main():
0,
noise_scheduler.config.num_train_timesteps,
)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
encoder_hidden_states = state.apply_fn(
batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True
)[0]

@@ -261,7 +261,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
)

# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
latents = latents * params["scheduler"].init_noise_sigma

if DEBUG:
# run with python for loop
for i in range(num_inference_steps):
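For the pipeline, the change above means the scheduler state travels inside the `params` dict passed to the (p)jitted call rather than hanging off `self.scheduler`. A hedged sketch of how callers assemble it; loading real weights is required for this to run, and the model id is only illustrative:

    from diffusers import FlaxStableDiffusionPipeline

    # Flax pipelines return both the pipeline object and its params pytree.
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

    # The scheduler state is created once and carried inside `params`, so the
    # sampling loop stays purely functional; inside the pipeline the initial
    # latents are scaled with params["scheduler"].init_noise_sigma.
    params["scheduler"] = pipeline.scheduler.create_state()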
@@ -15,7 +15,6 @@
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

@@ -26,51 +25,37 @@ from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CommonSchedulerState,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
add_noise_common,
)


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].

Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.

Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.

Returns:
betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
"""

def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return jnp.array(betas, dtype=jnp.float32)


@flax.struct.dataclass
class DDIMSchedulerState:
common: CommonSchedulerState
final_alpha_cumprod: jnp.ndarray

# setable values
init_noise_sigma: jnp.ndarray
timesteps: jnp.ndarray
alphas_cumprod: jnp.ndarray
num_inference_steps: Optional[int] = None

@classmethod
def create(cls, num_train_timesteps: int, alphas_cumprod: jnp.ndarray):
return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], alphas_cumprod=alphas_cumprod)
def create(
cls,
common: CommonSchedulerState,
final_alpha_cumprod: jnp.ndarray,
init_noise_sigma: jnp.ndarray,
timesteps: jnp.ndarray,
):
return cls(
common=common,
final_alpha_cumprod=final_alpha_cumprod,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
)


@dataclass
@@ -112,12 +97,15 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
`v-prediction` is not supported for this scheduler.

dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
the `dtype` used for params and computation.
"""

_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_deprecated_kwargs = ["predict_epsilon"]

dtype: jnp.dtype

@property
def has_state(self):
return True
@@ -129,43 +117,46 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[jnp.ndarray] = None,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
**kwargs,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")

if beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.dtype = dtype

self.alphas = 1.0 - self.betas

# HACK for now - clean up later (PVP)
self._alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
if common is None:
common = CommonSchedulerState.create(self)

# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else float(self._alphas_cumprod[0])
final_alpha_cumprod = (
jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0]
)

# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
init_noise_sigma = jnp.array(1.0, dtype=self.dtype)

timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]

return DDIMSchedulerState.create(
common=common,
final_alpha_cumprod=final_alpha_cumprod,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
)

def scale_model_input(
self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
@@ -181,21 +172,6 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
return sample

def create_state(self):
return DDIMSchedulerState.create(
num_train_timesteps=self.config.num_train_timesteps, alphas_cumprod=self._alphas_cumprod
)

def _get_variance(self, timestep, prev_timestep, alphas_cumprod):
alpha_prod_t = alphas_cumprod[timestep]
alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

return variance

def set_timesteps(
self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = ()
) -> DDIMSchedulerState:
@@ -208,15 +184,27 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
offset = self.config.steps_offset

step_ratio = self.config.num_train_timesteps // num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]
timesteps = timesteps + offset
# rounding to avoid issues when num_inference_step is power of 3
timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] + self.config.steps_offset

return state.replace(num_inference_steps=num_inference_steps, timesteps=timesteps)
return state.replace(
num_inference_steps=num_inference_steps,
timesteps=timesteps,
)

def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep):
alpha_prod_t = state.common.alphas_cumprod[timestep]
alpha_prod_t_prev = jnp.where(
prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

return variance

def step(
self,
@@ -224,6 +212,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
eta: float = 0.0,
return_dict: bool = True,
) -> Union[FlaxDDIMSchedulerOutput, Tuple]:
"""
@@ -259,17 +248,15 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
# - pred_sample_direction -> "direction pointing to x_t"
# - pred_prev_sample -> "x_t-1"

# TODO(Patrick) - eta is always 0.0 for now, allow to be set in step function
eta = 0.0

# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps

alphas_cumprod = state.alphas_cumprod
alphas_cumprod = state.common.alphas_cumprod
final_alpha_cumprod = state.final_alpha_cumprod

# 2. compute alphas, betas
alpha_prod_t = alphas_cumprod[timestep]
alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], final_alpha_cumprod)

beta_prod_t = 1 - alpha_prod_t

@@ -291,7 +278,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):

# 4. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
variance = self._get_variance(timestep, prev_timestep, alphas_cumprod)
variance = self._get_variance(state, timestep, prev_timestep)
std_dev_t = eta * variance ** (0.5)

# 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
@@ -307,20 +294,12 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):

def add_noise(
self,
state: DDIMSchedulerState,
original_samples: jnp.ndarray,
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.0
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
return add_noise_common(state.common, original_samples, noise, timesteps)

def __len__(self):
return self.config.num_train_timesteps
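Taken together, the DDIM changes mean a caller threads the state through every call. A minimal sketch of the resulting inference flow, assuming the new API above; the UNet call is a placeholder and the loop is plain Python (not jitted) for clarity:

    import jax.numpy as jnp
    from diffusers import FlaxDDIMScheduler

    scheduler = FlaxDDIMScheduler(num_train_timesteps=1000)
    state = scheduler.create_state()                                   # immutable pytree
    state = scheduler.set_timesteps(state, num_inference_steps=50, shape=(1, 4, 64, 64))

    sample = jnp.zeros((1, 4, 64, 64))
    for t in state.timesteps:
        model_output = sample                                          # placeholder for the UNet prediction
        output = scheduler.step(state, model_output, t, sample)
        sample, state = output.prev_sample, output.state               # state is returned, never mutated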
@@ -14,62 +14,36 @@

# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import flax
import jax
import jax.numpy as jnp
from jax import random

from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CommonSchedulerState,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
add_noise_common,
)


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].

Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.

Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.

Returns:
betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
"""

def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return jnp.array(betas, dtype=jnp.float32)


@flax.struct.dataclass
class DDPMSchedulerState:
common: CommonSchedulerState

# setable values
init_noise_sigma: jnp.ndarray
timesteps: jnp.ndarray
num_inference_steps: Optional[int] = None

@classmethod
def create(cls, num_train_timesteps: int):
return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1])
def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray):
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps)


@dataclass
@@ -106,11 +80,15 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
`v-prediction` is not supported for this scheduler.
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
the `dtype` used for params and computation.
"""

_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_deprecated_kwargs = ["predict_epsilon"]

dtype: jnp.dtype

@property
def has_state(self):
return True
@@ -126,35 +104,47 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
variance_type: str = "fixed_small",
clip_sample: bool = True,
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
**kwargs,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")

if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
elif beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.dtype = dtype

self.alphas = 1.0 - self.betas
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
self.one = jnp.array(1.0)
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState:
if common is None:
common = CommonSchedulerState.create(self)

def create_state(self):
return DDPMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
# standard deviation of the initial noise distribution
init_noise_sigma = jnp.array(1.0, dtype=self.dtype)

timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]

return DDPMSchedulerState.create(
common=common,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
)

def scale_model_input(
self, state: DDPMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
) -> jnp.ndarray:
"""
Args:
state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
sample (`jnp.ndarray`): input sample
timestep (`int`, optional): current timestep

Returns:
`jnp.ndarray`: scaled input sample
"""
return sample

def set_timesteps(
self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple = ()
@@ -168,20 +158,25 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
timesteps = jnp.arange(
0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps
)[::-1]
return state.replace(num_inference_steps=num_inference_steps, timesteps=timesteps)

def _get_variance(self, t, predicted_variance=None, variance_type=None):
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
step_ratio = self.config.num_train_timesteps // num_inference_steps
# creates integer timesteps by multiplying by ratio
# rounding to avoid issues when num_inference_step is power of 3
timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]

return state.replace(
num_inference_steps=num_inference_steps,
timesteps=timesteps,
)

def _get_variance(self, state: DDPMSchedulerState, t, predicted_variance=None, variance_type=None):
alpha_prod_t = state.common.alphas_cumprod[t]
alpha_prod_t_prev = jnp.where(t > 0, state.common.alphas_cumprod[t - 1], jnp.array(1.0, dtype=self.dtype))

# For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
# and sample from it to get previous sample
# x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * state.common.betas[t]

if variance_type is None:
variance_type = self.config.variance_type
@@ -193,15 +188,15 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
elif variance_type == "fixed_small_log":
variance = jnp.log(jnp.clip(variance, a_min=1e-20))
elif variance_type == "fixed_large":
variance = self.betas[t]
variance = state.common.betas[t]
elif variance_type == "fixed_large_log":
# Glide max_log
variance = jnp.log(self.betas[t])
variance = jnp.log(state.common.betas[t])
elif variance_type == "learned":
return predicted_variance
elif variance_type == "learned_range":
min_log = variance
max_log = self.betas[t]
max_log = state.common.betas[t]
frac = (predicted_variance + 1) / 2
variance = frac * max_log + (1 - frac) * min_log

@@ -213,9 +208,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
key: random.KeyArray,
key: jax.random.KeyArray = jax.random.PRNGKey(0),
return_dict: bool = True,
**kwargs,
) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
@@ -227,7 +221,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
key (`random.KeyArray`): a PRNG key.
key (`jax.random.KeyArray`): a PRNG key.
return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class

Returns:
@@ -235,16 +229,6 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
`tuple`. When returning a tuple, the first element is the sample tensor.

"""
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
if predict_epsilon is not None:
new_config = dict(self.config)
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
self._internal_dict = FrozenDict(new_config)

t = timestep

if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
@@ -253,8 +237,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
predicted_variance = None

# 1. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
alpha_prod_t = state.common.alphas_cumprod[t]
alpha_prod_t_prev = jnp.where(t > 0, state.common.alphas_cumprod[t - 1], jnp.array(1.0, dtype=self.dtype))
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

@@ -264,6 +248,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
@@ -276,19 +262,20 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):

# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * state.common.betas[t]) / beta_prod_t
current_sample_coeff = state.common.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t

# 5. Compute predicted previous sample µ_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

# 6. Add noise
variance = 0
if t > 0:
key = random.split(key, num=1)
noise = random.normal(key=key, shape=model_output.shape)
variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
def random_variance():
split_key = jax.random.split(key, num=1)
noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype)
return (self._get_variance(state, t, predicted_variance=predicted_variance) ** 0.5) * noise

variance = jnp.where(t > 0, random_variance(), jnp.zeros(model_output.shape, dtype=self.dtype))

pred_prev_sample = pred_prev_sample + variance

@@ -299,20 +286,12 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):

def add_noise(
self,
state: DDPMSchedulerState,
original_samples: jnp.ndarray,
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
return add_noise_common(state.common, original_samples, noise, timesteps)

def __len__(self):
return self.config.num_train_timesteps
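A note on the DDPM step change above: the Python `if t > 0:` around the noise term is replaced by `jnp.where`, because under `jax.jit` the timestep is a traced value and Python control flow on it would fail. A small standalone illustration of that branchless pattern (not the scheduler code itself; names are placeholders):

    import jax
    import jax.numpy as jnp

    def add_variance(t, pred_prev_sample, key, sigma):
        # noise is always computed; the select decides whether it is applied
        noise = jax.random.normal(key, pred_prev_sample.shape)
        variance = jnp.where(t > 0, sigma * noise, jnp.zeros_like(pred_prev_sample))
        return pred_prev_sample + variance

    stepped = jax.jit(add_variance)(jnp.array(10), jnp.ones((2, 2)), jax.random.PRNGKey(0), 0.1)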
@@ -14,7 +14,6 @@

# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

@@ -26,57 +25,49 @@ from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CommonSchedulerState,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
add_noise_common,
)


def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].

Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.

Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.

Returns:
betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
"""

def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return jnp.array(betas, dtype=jnp.float32)


@flax.struct.dataclass
class DPMSolverMultistepSchedulerState:
common: CommonSchedulerState
alpha_t: jnp.ndarray
sigma_t: jnp.ndarray
lambda_t: jnp.ndarray

# setable values
init_noise_sigma: jnp.ndarray
timesteps: jnp.ndarray
num_inference_steps: Optional[int] = None
timesteps: Optional[jnp.ndarray] = None

# running values
model_outputs: Optional[jnp.ndarray] = None
lower_order_nums: Optional[int] = None
step_index: Optional[int] = None
prev_timestep: Optional[int] = None
lower_order_nums: Optional[jnp.int32] = None
prev_timestep: Optional[jnp.int32] = None
cur_sample: Optional[jnp.ndarray] = None

@classmethod
def create(cls, num_train_timesteps: int):
return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1])
def create(
cls,
common: CommonSchedulerState,
alpha_t: jnp.ndarray,
sigma_t: jnp.ndarray,
lambda_t: jnp.ndarray,
init_noise_sigma: jnp.ndarray,
timesteps: jnp.ndarray,
):
return cls(
common=common,
alpha_t=alpha_t,
sigma_t=sigma_t,
lambda_t=lambda_t,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
)


@dataclass
@@ -145,12 +136,15 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
lower_order_final (`bool`, default `True`):
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.

dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
the `dtype` used for params and computation.
"""

_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_deprecated_kwargs = ["predict_epsilon"]

dtype: jnp.dtype

@property
def has_state(self):
return True
@@ -171,47 +165,47 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
dtype: jnp.dtype = jnp.float32,
**kwargs,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")

if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
elif beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.dtype = dtype

def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState:
if common is None:
common = CommonSchedulerState.create(self)

self.alphas = 1.0 - self.betas
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
# Currently we only support VP-type noise schedule
self.alpha_t = jnp.sqrt(self.alphas_cumprod)
self.sigma_t = jnp.sqrt(1 - self.alphas_cumprod)
self.lambda_t = jnp.log(self.alpha_t) - jnp.log(self.sigma_t)

# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
alpha_t = jnp.sqrt(common.alphas_cumprod)
sigma_t = jnp.sqrt(1 - common.alphas_cumprod)
lambda_t = jnp.log(alpha_t) - jnp.log(sigma_t)

# settings for DPM-Solver
if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
if solver_type not in ["midpoint", "heun"]:
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
if self.config.algorithm_type not in ["dpmsolver", "dpmsolver++"]:
raise NotImplementedError(f"{self.config.algorithm_type} does is not implemented for {self.__class__}")
if self.config.solver_type not in ["midpoint", "heun"]:
raise NotImplementedError(f"{self.config.solver_type} does is not implemented for {self.__class__}")

def create_state(self):
return DPMSolverMultistepSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
# standard deviation of the initial noise distribution
init_noise_sigma = jnp.array(1.0, dtype=self.dtype)

timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]

return DPMSolverMultistepSchedulerState.create(
common=common,
alpha_t=alpha_t,
sigma_t=sigma_t,
lambda_t=lambda_t,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
)

def set_timesteps(
self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple
@@ -227,24 +221,32 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
shape (`Tuple`):
the shape of the samples to be generated.
"""

timesteps = (
jnp.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
.round()[::-1][:-1]
.astype(jnp.int32)
)

# initial running values

model_outputs = jnp.zeros((self.config.solver_order,) + shape, dtype=self.dtype)
lower_order_nums = jnp.int32(0)
prev_timestep = jnp.int32(-1)
cur_sample = jnp.zeros(shape, dtype=self.dtype)

return state.replace(
num_inference_steps=num_inference_steps,
timesteps=timesteps,
model_outputs=jnp.zeros((self.config.solver_order,) + shape),
lower_order_nums=0,
step_index=0,
prev_timestep=-1,
cur_sample=jnp.zeros(shape),
model_outputs=model_outputs,
lower_order_nums=lower_order_nums,
prev_timestep=prev_timestep,
cur_sample=cur_sample,
)
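The running values initialized in `set_timesteps` above are fixed-shape buffers (a `(solver_order, *sample_shape)` history of model outputs, integer counters, the current sample), so the state pytree keeps a constant structure across a jitted or scanned sampling loop. A small illustration of how such a buffer is later updated with a roll plus functional set, assuming placeholder shapes:

    import jax.numpy as jnp

    solver_order, shape = 2, (1, 4, 8, 8)
    model_outputs = jnp.zeros((solver_order,) + shape)

    new_output = jnp.ones(shape)                           # placeholder model prediction
    model_outputs = jnp.roll(model_outputs, -1, axis=0)    # drop the oldest entry
    model_outputs = model_outputs.at[-1].set(new_output)   # append the newest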
def convert_model_output(
self,
state: DPMSolverMultistepSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
@@ -271,12 +273,12 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
# DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type == "dpmsolver++":
if self.config.prediction_type == "epsilon":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample":
x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
x0_pred = alpha_t * sample - sigma_t * model_output
else:
raise ValueError(
@@ -299,11 +301,11 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
if self.config.prediction_type == "epsilon":
return model_output
elif self.config.prediction_type == "sample":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
epsilon = alpha_t * model_output + sigma_t * sample
return epsilon
else:
@@ -313,7 +315,12 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
)

def dpm_solver_first_order_update(
self, model_output: jnp.ndarray, timestep: int, prev_timestep: int, sample: jnp.ndarray
self,
state: DPMSolverMultistepSchedulerState,
model_output: jnp.ndarray,
timestep: int,
prev_timestep: int,
sample: jnp.ndarray,
) -> jnp.ndarray:
"""
One step for the first-order DPM-Solver (equivalent to DDIM).
@@ -332,9 +339,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
t, s0 = prev_timestep, timestep
m0 = model_output
lambda_t, lambda_s = self.lambda_t[t], self.lambda_t[s0]
alpha_t, alpha_s = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s = self.sigma_t[t], self.sigma_t[s0]
lambda_t, lambda_s = state.lambda_t[t], state.lambda_t[s0]
alpha_t, alpha_s = state.alpha_t[t], state.alpha_t[s0]
sigma_t, sigma_s = state.sigma_t[t], state.sigma_t[s0]
h = lambda_t - lambda_s
if self.config.algorithm_type == "dpmsolver++":
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * m0
@@ -344,6 +351,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):

def multistep_dpm_solver_second_order_update(
self,
state: DPMSolverMultistepSchedulerState,
model_output_list: jnp.ndarray,
timestep_list: List[int],
prev_timestep: int,
@@ -365,9 +373,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
m0, m1 = model_output_list[-1], model_output_list[-2]
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
lambda_t, lambda_s0, lambda_s1 = state.lambda_t[t], state.lambda_t[s0], state.lambda_t[s1]
alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0]
sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0]
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
r0 = h_0 / h
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
@@ -403,6 +411,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):

def multistep_dpm_solver_third_order_update(
self,
state: DPMSolverMultistepSchedulerState,
model_output_list: jnp.ndarray,
timestep_list: List[int],
prev_timestep: int,
@@ -425,13 +434,13 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
self.lambda_t[t],
self.lambda_t[s0],
self.lambda_t[s1],
self.lambda_t[s2],
state.lambda_t[t],
state.lambda_t[s0],
state.lambda_t[s1],
state.lambda_t[s2],
)
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0]
sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0]
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
r0, r1 = h_0 / h, h_1 / h
D0 = m0
@@ -482,14 +491,17 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.

"""
prev_timestep = jax.lax.cond(
state.step_index == len(state.timesteps) - 1,
lambda _: 0,
lambda _: state.timesteps[state.step_index + 1],
(),
)
if state.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)

model_output = self.convert_model_output(model_output, timestep, sample)
(step_index,) = jnp.where(state.timesteps == timestep, size=1)
step_index = step_index[0]

prev_timestep = jax.lax.select(step_index == len(state.timesteps) - 1, 0, state.timesteps[step_index + 1])

model_output = self.convert_model_output(state, model_output, timestep, sample)

model_outputs_new = jnp.roll(state.model_outputs, -1, axis=0)
model_outputs_new = model_outputs_new.at[-1].set(model_output)
@@ -501,16 +513,18 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):

def step_1(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
return self.dpm_solver_first_order_update(
state,
state.model_outputs[-1],
state.timesteps[state.step_index],
state.timesteps[step_index],
state.prev_timestep,
state.cur_sample,
)

def step_23(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
def step_2(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
timestep_list = jnp.array([state.timesteps[state.step_index - 1], state.timesteps[state.step_index]])
timestep_list = jnp.array([state.timesteps[step_index - 1], state.timesteps[step_index]])
return self.multistep_dpm_solver_second_order_update(
state,
state.model_outputs,
timestep_list,
state.prev_timestep,
@@ -520,65 +534,67 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
timestep_list = jnp.array(
[
state.timesteps[state.step_index - 2],
state.timesteps[state.step_index - 1],
state.timesteps[state.step_index],
state.timesteps[step_index - 2],
state.timesteps[step_index - 1],
state.timesteps[step_index],
]
)
return self.multistep_dpm_solver_third_order_update(
state,
state.model_outputs,
timestep_list,
state.prev_timestep,
state.cur_sample,
)

step_2_output = step_2(state)
step_3_output = step_3(state)

if self.config.solver_order == 2:
return step_2(state)
return step_2_output
elif self.config.lower_order_final and len(state.timesteps) < 15:
return jax.lax.cond(
return jax.lax.select(
state.lower_order_nums < 2,
step_2,
lambda state: jax.lax.cond(
state.step_index == len(state.timesteps) - 2,
step_2,
step_3,
state,
step_2_output,
jax.lax.select(
step_index == len(state.timesteps) - 2,
step_2_output,
step_3_output,
),
state,
)
else:
return jax.lax.cond(
return jax.lax.select(
state.lower_order_nums < 2,
step_2,
step_3,
state,
step_2_output,
step_3_output,
)

step_1_output = step_1(state)
step_23_output = step_23(state)

if self.config.solver_order == 1:
prev_sample = step_1(state)
prev_sample = step_1_output

elif self.config.lower_order_final and len(state.timesteps) < 15:
prev_sample = jax.lax.cond(
prev_sample = jax.lax.select(
state.lower_order_nums < 1,
step_1,
lambda state: jax.lax.cond(
state.step_index == len(state.timesteps) - 1,
step_1,
step_23,
state,
step_1_output,
jax.lax.select(
step_index == len(state.timesteps) - 1,
step_1_output,
step_23_output,
),
state,
)

else:
prev_sample = jax.lax.cond(
prev_sample = jax.lax.select(
state.lower_order_nums < 1,
step_1,
step_23,
state,
step_1_output,
step_23_output,
)

state = state.replace(
lower_order_nums=jnp.minimum(state.lower_order_nums + 1, self.config.solver_order),
step_index=(state.step_index + 1),
)

if not return_dict:
@@ -606,20 +622,12 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):

def add_noise(
self,
state: DPMSolverMultistepSchedulerState,
original_samples: jnp.ndarray,
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.0
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
return add_noise_common(state.common, original_samples, noise, timesteps)

def __len__(self):
return self.config.num_train_timesteps
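The refactored DPM-Solver step drops the `step_index` counter from the state and instead relies on two jit-friendly tricks shown in the hunks above: locating the current timestep inside `state.timesteps` with a sized `jnp.where`, and selecting between already-computed branch outputs with `jax.lax.select` rather than calling branch functions through `jax.lax.cond`. A small standalone illustration mirroring those lines, with a made-up timestep table:

    import jax
    import jax.numpy as jnp

    timesteps = jnp.array([999, 749, 499, 249, 0])   # illustrative inference schedule
    timestep = jnp.array(499)

    # size=1 gives a statically shaped result, so this works under jit
    (step_index,) = jnp.where(timesteps == timestep, size=1)
    step_index = step_index[0]

    # both branches are plain values, so select just picks one
    prev_timestep = jax.lax.select(step_index == len(timesteps) - 1, 0, timesteps[step_index + 1])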
@@ -233,5 +233,5 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):

return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state)

def add_noise(self, original_samples, noise, timesteps):
def add_noise(self, state: KarrasVeSchedulerState, original_samples, noise, timesteps):
raise NotImplementedError()

@@ -22,6 +22,7 @@ from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CommonSchedulerState,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
@@ -30,15 +31,22 @@ from .scheduling_utils_flax import (

@flax.struct.dataclass
class LMSDiscreteSchedulerState:
common: CommonSchedulerState

# setable values
init_noise_sigma: jnp.ndarray
timesteps: jnp.ndarray
sigmas: jnp.ndarray
num_inference_steps: Optional[int] = None
timesteps: Optional[jnp.ndarray] = None
sigmas: Optional[jnp.ndarray] = None
derivatives: jnp.ndarray = jnp.array([])

# running values
derivatives: Optional[jnp.ndarray] = None

@classmethod
def create(cls, num_train_timesteps: int, sigmas: jnp.ndarray):
return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], sigmas=sigmas)
def create(
cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
):
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)


@dataclass
@@ -66,10 +74,18 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
`linear` or `scaled_linear`.
trained_betas (`jnp.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
the `dtype` used for params and computation.
"""

_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()

dtype: jnp.dtype

@property
def has_state(self):
return True
@@ -82,24 +98,26 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[jnp.ndarray] = None,
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
):
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
elif beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.dtype = dtype

self.alphas = 1.0 - self.betas
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
def create_state(self, common: Optional[CommonSchedulerState] = None) -> LMSDiscreteSchedulerState:
if common is None:
common = CommonSchedulerState.create(self)

def create_state(self):
self.state = LMSDiscreteSchedulerState.create(
num_train_timesteps=self.config.num_train_timesteps,
sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5,
timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5

# standard deviation of the initial noise distribution
init_noise_sigma = sigmas.max()

return LMSDiscreteSchedulerState.create(
common=common,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
sigmas=sigmas,
)

def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray:
@@ -118,11 +136,13 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
`jnp.ndarray`: scaled input sample
"""
(step_index,) = jnp.where(state.timesteps == timestep, size=1)
step_index = step_index[0]

sigma = state.sigmas[step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample

def get_lms_coefficient(self, state, order, t, current_order):
def get_lms_coefficient(self, state: LMSDiscreteSchedulerState, order, t, current_order):
"""
Compute a linear multistep coefficient.

@@ -156,20 +176,28 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=jnp.float32)

low_idx = jnp.floor(timesteps).astype(int)
high_idx = jnp.ceil(timesteps).astype(int)
timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype)

low_idx = jnp.floor(timesteps).astype(jnp.int32)
high_idx = jnp.ceil(timesteps).astype(jnp.int32)

frac = jnp.mod(timesteps, 1.0)
sigmas = jnp.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)

sigmas = ((1 - state.common.alphas_cumprod) / state.common.alphas_cumprod) ** 0.5
sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
sigmas = jnp.concatenate([sigmas, jnp.array([0.0])]).astype(jnp.float32)
sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)])

timesteps = timesteps.astype(jnp.int32)

# initial running values
derivatives = jnp.zeros((0,) + shape, dtype=self.dtype)

return state.replace(
num_inference_steps=num_inference_steps,
timesteps=timesteps.astype(int),
derivatives=jnp.array([]),
timesteps=timesteps,
sigmas=sigmas,
num_inference_steps=num_inference_steps,
derivatives=derivatives,
)
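A small illustration of the sigma interpolation performed in the LMS `set_timesteps` above: the (possibly fractional) inference timesteps are mapped onto the training sigma table by linear interpolation between the neighbouring integer indices, and a final sigma of 0 is appended. The alpha schedule below is a stand-in, not the real one:

    import jax.numpy as jnp

    alphas_cumprod = jnp.linspace(0.9999, 0.01, 1000)            # stand-in training schedule
    sigmas_table = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5

    timesteps = jnp.linspace(999, 0, 50, dtype=jnp.float32)
    low_idx = jnp.floor(timesteps).astype(jnp.int32)
    high_idx = jnp.ceil(timesteps).astype(jnp.int32)
    frac = jnp.mod(timesteps, 1.0)

    sigmas = (1 - frac) * sigmas_table[low_idx] + frac * sigmas_table[high_idx]
    sigmas = jnp.concatenate([sigmas, jnp.array([0.0])])         # terminal sigma of 0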
def step(
|
||||
@@ -199,10 +227,23 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
`tuple`. When returning a tuple, the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
if state.num_inference_steps is None:
|
||||
raise ValueError(
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
sigma = state.sigmas[timestep]
|
||||
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
pred_original_sample = sample - sigma * model_output
|
||||
if self.config.prediction_type == "epsilon":
|
||||
pred_original_sample = sample - sigma * model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
# * c_out + input * c_skip
|
||||
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
||||
)
|
||||
|
||||
# 2. Convert to an ODE derivative
|
||||
derivative = (sample - pred_original_sample) / sigma
|
||||
|
||||
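For readers less familiar with the two parameterizations handled in the `step` hunk above, the following standalone sketch (toy values, no scheduler object) shows how an epsilon prediction and a v prediction are both turned into a predicted x_0 at a given sigma, mirroring the formulas in the diff:

import jax.numpy as jnp

sigma = jnp.float32(2.0)              # hypothetical noise level
sample = jnp.ones((4,))               # noisy sample x_t
model_output = jnp.full((4,), 0.5)    # network output

# "epsilon": the model predicts the added noise directly
x0_from_eps = sample - sigma * model_output

# "v_prediction": the model predicts a velocity, combined with c_skip / c_out weights
x0_from_v = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
print(x0_from_eps, x0_from_v)
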
@@ -14,7 +14,6 @@

# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

@@ -25,59 +24,45 @@ import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import (
    _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
    CommonSchedulerState,
    FlaxSchedulerMixin,
    FlaxSchedulerOutput,
    broadcast_to_shape_from_left,
    add_noise_common,
)


def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.


    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.

    Returns:
        betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
    """

    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return jnp.array(betas, dtype=jnp.float32)


@flax.struct.dataclass
class PNDMSchedulerState:
    common: CommonSchedulerState
    final_alpha_cumprod: jnp.ndarray

    # setable values
    _timesteps: jnp.ndarray
    init_noise_sigma: jnp.ndarray
    timesteps: jnp.ndarray
    num_inference_steps: Optional[int] = None
    prk_timesteps: Optional[jnp.ndarray] = None
    plms_timesteps: Optional[jnp.ndarray] = None
    timesteps: Optional[jnp.ndarray] = None

    # running values
    cur_model_output: Optional[jnp.ndarray] = None
    counter: int = 0
    counter: Optional[jnp.int32] = None
    cur_sample: Optional[jnp.ndarray] = None
    ets: jnp.ndarray = jnp.array([])
    ets: Optional[jnp.ndarray] = None

    @classmethod
    def create(cls, num_train_timesteps: int):
        return cls(_timesteps=jnp.arange(0, num_train_timesteps)[::-1])
    def create(
        cls,
        common: CommonSchedulerState,
        final_alpha_cumprod: jnp.ndarray,
        init_noise_sigma: jnp.ndarray,
        timesteps: jnp.ndarray,
    ):
        return cls(
            common=common,
            final_alpha_cumprod=final_alpha_cumprod,
            init_noise_sigma=init_noise_sigma,
            timesteps=timesteps,
        )


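`PNDMSchedulerState` above is a `flax.struct.dataclass`, i.e. an immutable pytree that is updated functionally rather than in place. A tiny illustrative sketch of that pattern (the field names below are invented, not the scheduler's):

import flax
import jax.numpy as jnp


@flax.struct.dataclass
class ToyState:
    counter: jnp.ndarray
    buffer: jnp.ndarray


state = ToyState(counter=jnp.int32(0), buffer=jnp.zeros((4,)))
# .replace returns a new state and leaves the old one untouched, which is what
# lets the schedulers stay stateless and jit/pmap friendly.
state = state.replace(counter=state.counter + 1)
print(state.counter)
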
@dataclass
@@ -117,10 +102,19 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
            an offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
            stable diffusion.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
            https://imagen.research.google/video/paper.pdf)
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            the `dtype` used for params and computation.
    """

    _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()

    dtype: jnp.dtype
    pndm_order: int

    @property
    def has_state(self):
        return True
@@ -136,35 +130,39 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
        skip_prk_steps: bool = False,
        set_alpha_to_one: bool = False,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        dtype: jnp.dtype = jnp.float32,
    ):
        if trained_betas is not None:
            self.betas = jnp.asarray(trained_betas)
        elif beta_schedule == "linear":
            self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)

        self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
        self.dtype = dtype

        # For now we only support F-PNDM, i.e. the runge-kutta method
        # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
        # mainly at formula (9), (12), (13) and the Algorithm 2.
        self.pndm_order = 4

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0
    def create_state(self, common: Optional[CommonSchedulerState] = None) -> PNDMSchedulerState:
        if common is None:
            common = CommonSchedulerState.create(self)

    def create_state(self):
        return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        final_alpha_cumprod = (
            jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0]
        )

        # standard deviation of the initial noise distribution
        init_noise_sigma = jnp.array(1.0, dtype=self.dtype)

        timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]

        return PNDMSchedulerState.create(
            common=common,
            final_alpha_cumprod=final_alpha_cumprod,
            init_noise_sigma=init_noise_sigma,
            timesteps=timesteps,
        )

    def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, shape: Tuple) -> PNDMSchedulerState:
        """
@@ -178,42 +176,47 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
            shape (`Tuple`):
                the shape of the samples to be generated.
        """
        offset = self.config.steps_offset

        step_ratio = self.config.num_train_timesteps // num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # rounding to avoid issues when num_inference_step is power of 3
        _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + offset

        state = state.replace(num_inference_steps=num_inference_steps, _timesteps=_timesteps)
        _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + self.config.steps_offset

        if self.config.skip_prk_steps:
            # for some models like stable diffusion the prk steps can/should be skipped to
            # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
            # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
            state = state.replace(
                prk_timesteps=jnp.array([]),
                plms_timesteps=jnp.concatenate(
                    [state._timesteps[:-1], state._timesteps[-2:-1], state._timesteps[-1:]]
                )[::-1],
            )

            prk_timesteps = jnp.array([], dtype=jnp.int32)
            plms_timesteps = jnp.concatenate([_timesteps[:-1], _timesteps[-2:-1], _timesteps[-1:]])[::-1]

        else:
            prk_timesteps = jnp.array(state._timesteps[-self.pndm_order :]).repeat(2) + jnp.tile(
                jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
            prk_timesteps = _timesteps[-self.pndm_order :].repeat(2) + jnp.tile(
                jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2], dtype=jnp.int32),
                self.pndm_order,
            )

            state = state.replace(
                prk_timesteps=(prk_timesteps[:-1].repeat(2)[1:-1])[::-1],
                plms_timesteps=state._timesteps[:-3][::-1],
            )
            prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1]
            plms_timesteps = _timesteps[:-3][::-1]

        timesteps = jnp.concatenate([prk_timesteps, plms_timesteps])

        # initial running values

        cur_model_output = jnp.zeros(shape, dtype=self.dtype)
        counter = jnp.int32(0)
        cur_sample = jnp.zeros(shape, dtype=self.dtype)
        ets = jnp.zeros((4,) + shape, dtype=self.dtype)

        return state.replace(
            timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int32),
            counter=0,
            # Reserve space for the state variables
            cur_model_output=jnp.zeros(shape),
            cur_sample=jnp.zeros(shape),
            ets=jnp.zeros((4,) + shape),
            timesteps=timesteps,
            num_inference_steps=num_inference_steps,
            prk_timesteps=prk_timesteps,
            plms_timesteps=plms_timesteps,
            cur_model_output=cur_model_output,
            counter=counter,
            cur_sample=cur_sample,
            ets=ets,
        )

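To make the PLMS timestep layout computed above concrete, here is a small numeric sketch of what `set_timesteps` produces when `skip_prk_steps=True`; 1000 training steps and 8 inference steps are chosen only for readability:

import jax.numpy as jnp

num_train_timesteps, num_inference_steps, steps_offset = 1000, 8, 0
step_ratio = num_train_timesteps // num_inference_steps
_timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + steps_offset

# PLMS duplicates the second-to-last step and reverses, exactly as in the hunk above
plms_timesteps = jnp.concatenate([_timesteps[:-1], _timesteps[-2:-1], _timesteps[-1:]])[::-1]
print(plms_timesteps)  # -> [875 750 750 625 500 375 250 125 0]
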
    def scale_model_input(
@@ -260,19 +263,27 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
            `tuple`. When returning a tuple, the first element is the sample tensor.

        """
        if self.config.skip_prk_steps:
            prev_sample, state = self.step_plms(
                state=state, model_output=model_output, timestep=timestep, sample=sample

        if state.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.config.skip_prk_steps:
            prev_sample, state = self.step_plms(state, model_output, timestep, sample)
        else:
            prev_sample, state = jax.lax.switch(
                jnp.where(state.counter < len(state.prk_timesteps), 0, 1),
                (self.step_prk, self.step_plms),
                # Args to either branch
                state,
                model_output,
                timestep,
                sample,
            prk_prev_sample, prk_state = self.step_prk(state, model_output, timestep, sample)
            plms_prev_sample, plms_state = self.step_plms(state, model_output, timestep, sample)

            cond = state.counter < len(state.prk_timesteps)

            prev_sample = jax.lax.select(cond, prk_prev_sample, plms_prev_sample)

            state = state.replace(
                cur_model_output=jax.lax.select(cond, prk_state.cur_model_output, plms_state.cur_model_output),
                ets=jax.lax.select(cond, prk_state.ets, plms_state.ets),
                cur_sample=jax.lax.select(cond, prk_state.cur_sample, plms_state.cur_sample),
                counter=jax.lax.select(cond, prk_state.counter, plms_state.counter),
            )

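The rewritten branch above computes both the PRK and the PLMS update and then picks fields with `jax.lax.select`, instead of dispatching through `jax.lax.switch`. A minimal sketch of why that works under `jit` (toy branch functions, nothing scheduler specific):

import jax
import jax.numpy as jnp


def branch_a(x):
    return x + 1.0


def branch_b(x):
    return x * 2.0


@jax.jit
def pick(cond, x):
    # Both branches are traced and evaluated; lax.select keeps one result. That is
    # acceptable here because both scheduler steps are cheap next to the UNet call.
    return jax.lax.select(cond, branch_a(x), branch_b(x))


x = jnp.arange(3.0)
print(pick(jnp.array(True), x))   # [1. 2. 3.]
print(pick(jnp.array(False), x))  # [0. 2. 4.]
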
        if not return_dict:
@@ -304,6 +315,7 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
            `tuple`. When returning a tuple, the first element is the sample tensor.

        """

        if state.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
@@ -315,37 +327,34 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
        prev_timestep = timestep - diff_to_prev
        timestep = state.prk_timesteps[state.counter // 4 * 4]

        def remainder_0(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
            return (
                state.replace(
                    cur_model_output=state.cur_model_output + 1 / 6 * model_output,
                    ets=state.ets.at[ets_at].set(model_output),
                    cur_sample=sample,
                ),
                model_output,
            )
        model_output = jax.lax.select(
            (state.counter % 4) != 3,
            model_output,  # remainder 0, 1, 2
            state.cur_model_output + 1 / 6 * model_output,  # remainder 3
        )

        def remainder_1(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
            return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output

        def remainder_2(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
            return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output

        def remainder_3(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
            model_output = state.cur_model_output + 1 / 6 * model_output
            return state.replace(cur_model_output=jnp.zeros_like(state.cur_model_output)), model_output

        state, model_output = jax.lax.switch(
            state.counter % 4,
            (remainder_0, remainder_1, remainder_2, remainder_3),
            # Args to either branch
            state,
            model_output,
            state.counter // 4,
        state = state.replace(
            cur_model_output=jax.lax.select_n(
                state.counter % 4,
                state.cur_model_output + 1 / 6 * model_output,  # remainder 0
                state.cur_model_output + 1 / 3 * model_output,  # remainder 1
                state.cur_model_output + 1 / 3 * model_output,  # remainder 2
                jnp.zeros_like(state.cur_model_output),  # remainder 3
            ),
            ets=jax.lax.select(
                (state.counter % 4) == 0,
                state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output),  # remainder 0
                state.ets,  # remainder 1, 2, 3
            ),
            cur_sample=jax.lax.select(
                (state.counter % 4) == 0,
                sample,  # remainder 0
                state.cur_sample,  # remainder 1, 2, 3
            ),
        )

        cur_sample = state.cur_sample
        prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
        prev_sample = self._get_prev_sample(state, cur_sample, timestep, prev_timestep, model_output)
        state = state.replace(counter=state.counter + 1)

        return (prev_sample, state)
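`jax.lax.select_n`, used above for the four-way remainder update, generalizes `select` to more than two cases: the first argument is an integer index and the remaining arguments are the candidate arrays, all of which are evaluated. A short sketch:

import jax.numpy as jnp
from jax import lax

idx = jnp.int32(2)                                    # plays the role of state.counter % 4
cases = [jnp.full((2,), float(i)) for i in range(4)]  # one candidate per remainder
picked = lax.select_n(idx, *cases)
print(picked)  # [2. 2.]
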
@@ -374,18 +383,13 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
            `tuple`. When returning a tuple, the first element is the sample tensor.

        """

        if state.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if not self.config.skip_prk_steps and len(state.ets) < 3:
            raise ValueError(
                f"{self.__class__} can only be run AFTER scheduler has been run "
                "in 'prk' mode for at least 12 iterations "
                "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
                "for more information."
            )
        # NOTE: There is no way to check in the jitted runtime if the prk mode was run before

        prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
        prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0)
@@ -417,64 +421,39 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
        # else:
        #     model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4])

        def counter_0(state: PNDMSchedulerState):
            ets = state.ets.at[0].set(model_output)
            return state.replace(
                ets=ets,
                cur_sample=sample,
                cur_model_output=jnp.array(model_output, dtype=jnp.float32),
            )
        state = state.replace(
            ets=jax.lax.select(
                state.counter != 1,
                state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output),  # counter != 1
                state.ets,  # counter 1
            ),
            cur_sample=jax.lax.select(
                state.counter != 1,
                sample,  # counter != 1
                state.cur_sample,  # counter 1
            ),
        )

        def counter_1(state: PNDMSchedulerState):
            return state.replace(
                cur_model_output=(model_output + state.ets[0]) / 2,
            )

        def counter_2(state: PNDMSchedulerState):
            ets = state.ets.at[1].set(model_output)
            return state.replace(
                ets=ets,
                cur_model_output=(3 * ets[1] - ets[0]) / 2,
                cur_sample=sample,
            )

        def counter_3(state: PNDMSchedulerState):
            ets = state.ets.at[2].set(model_output)
            return state.replace(
                ets=ets,
                cur_model_output=(23 * ets[2] - 16 * ets[1] + 5 * ets[0]) / 12,
                cur_sample=sample,
            )

        def counter_other(state: PNDMSchedulerState):
            ets = state.ets.at[3].set(model_output)
            next_model_output = (1 / 24) * (55 * ets[3] - 59 * ets[2] + 37 * ets[1] - 9 * ets[0])

            ets = ets.at[0].set(ets[1])
            ets = ets.at[1].set(ets[2])
            ets = ets.at[2].set(ets[3])

            return state.replace(
                ets=ets,
                cur_model_output=next_model_output,
                cur_sample=sample,
            )

        counter = jnp.clip(state.counter, 0, 4)
        state = jax.lax.switch(
            counter,
            [counter_0, counter_1, counter_2, counter_3, counter_other],
            state,
        state = state.replace(
            cur_model_output=jax.lax.select_n(
                jnp.clip(state.counter, 0, 4),
                model_output,  # counter 0
                (model_output + state.ets[-1]) / 2,  # counter 1
                (3 * state.ets[-1] - state.ets[-2]) / 2,  # counter 2
                (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12,  # counter 3
                (1 / 24)
                * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]),  # counter >= 4
            ),
        )

        sample = state.cur_sample
        model_output = state.cur_model_output
        prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
        prev_sample = self._get_prev_sample(state, sample, timestep, prev_timestep, model_output)
        state = state.replace(counter=state.counter + 1)

        return (prev_sample, state)

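The fixed-size `ets` history used by both `step_prk` and `step_plms` above is updated with JAX's functional index-update syntax: shift the window left and write the newest prediction in one expression. A toy version of that rolling update:

import jax.numpy as jnp

ets = jnp.stack([jnp.full((2,), float(i)) for i in range(4)])  # pretend history of 4 predictions
new = jnp.full((2,), 9.0)

# drop the oldest entry, shift the rest left, then append the newest one
ets = ets.at[0:3].set(ets[1:4]).at[3].set(new)
print(ets[:, 0])  # [1. 2. 3. 9.]
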
    def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
    def _get_prev_sample(self, state: PNDMSchedulerState, sample, timestep, prev_timestep, model_output):
        # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
        # this function computes x_(t−δ) using the formula of (9)
        # Note that x_t needs to be added to both sides of the equation
@@ -487,11 +466,20 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
        # sample -> x_t
        # model_output -> e_θ(x_t, t)
        # prev_sample -> x_(t−δ)
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, self.alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
        alpha_prod_t = state.common.alphas_cumprod[timestep]
        alpha_prod_t_prev = jnp.where(
            prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
        )
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        if self.config.prediction_type == "v_prediction":
            model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        elif self.config.prediction_type != "epsilon":
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`"
            )

        # corresponds to (α_(t−δ) - α_t) divided by
        # denominator of x_t in formula (9) and plus 1
        # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) =
@@ -512,20 +500,12 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):

    def add_noise(
        self,
        state: PNDMSchedulerState,
        original_samples: jnp.ndarray,
        noise: jnp.ndarray,
        timesteps: jnp.ndarray,
    ) -> jnp.ndarray:
        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples
        return add_noise_common(state.common, original_samples, noise, timesteps)

    def __len__(self):
        return self.config.num_train_timesteps

@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import math
import os
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import flax
import jax.numpy as jnp

from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
@@ -50,6 +52,7 @@ class FlaxSchedulerMixin:
    """

    config_name = SCHEDULER_CONFIG_NAME
    ignore_for_config = ["dtype"]
    _compatibles = []
    has_compatibles = True

@@ -167,3 +170,90 @@ class FlaxSchedulerMixin:
def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray:
    assert len(shape) >= x.ndim
    return jnp.broadcast_to(x.reshape(x.shape + (1,) * (len(shape) - x.ndim)), shape)


def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999, dtype=jnp.float32) -> jnp.ndarray:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.


    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.

    Returns:
        betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
    """

    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return jnp.array(betas, dtype=dtype)


@flax.struct.dataclass
class CommonSchedulerState:
    alphas: jnp.ndarray
    betas: jnp.ndarray
    alphas_cumprod: jnp.ndarray

    @classmethod
    def create(cls, scheduler):
        config = scheduler.config

        if config.trained_betas is not None:
            betas = jnp.asarray(config.trained_betas, dtype=scheduler.dtype)
        elif config.beta_schedule == "linear":
            betas = jnp.linspace(config.beta_start, config.beta_end, config.num_train_timesteps, dtype=scheduler.dtype)
        elif config.beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            betas = (
                jnp.linspace(
                    config.beta_start**0.5, config.beta_end**0.5, config.num_train_timesteps, dtype=scheduler.dtype
                )
                ** 2
            )
        elif config.beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            betas = betas_for_alpha_bar(config.num_train_timesteps, dtype=scheduler.dtype)
        else:
            raise NotImplementedError(
                f"beta_schedule {config.beta_schedule} is not implemented for scheduler {scheduler.__class__.__name__}"
            )

        alphas = 1.0 - betas

        alphas_cumprod = jnp.cumprod(alphas, axis=0)

        return cls(
            alphas=alphas,
            betas=betas,
            alphas_cumprod=alphas_cumprod,
        )


def add_noise_common(
    state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
):
    alphas_cumprod = state.alphas_cumprod

    sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
    sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
    sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

    noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
    return noisy_samples

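Since `add_noise_common` above is now the single forward-diffusion helper shared by the Flax schedulers, the following self-contained check sketches its broadcasting behaviour with toy shapes; the `reshape` stands in for `broadcast_to_shape_from_left` and the alphas are invented:

import jax.numpy as jnp

alphas_cumprod = jnp.linspace(0.99, 0.01, 1000)    # toy cumulative alphas
timesteps = jnp.array([10, 500])                   # one timestep per batch element
original = jnp.ones((2, 3, 4, 4))
noise = jnp.zeros((2, 3, 4, 4))

sqrt_alpha = (alphas_cumprod[timesteps] ** 0.5).reshape(-1, 1, 1, 1)
sqrt_one_minus = ((1 - alphas_cumprod[timesteps]) ** 0.5).reshape(-1, 1, 1, 1)
noisy = sqrt_alpha * original + sqrt_one_minus * noise
print(noisy.shape)  # (2, 3, 4, 4)
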
@@ -296,10 +296,11 @@ class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()

        assert jnp.sum(jnp.abs(scheduler._get_variance(0) - 0.0)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(999) - 0.02)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0) - 0.0)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487) - 0.00979)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999) - 0.02)) < 1e-5

    def test_full_loop_no_noise(self):
        scheduler_class = self.scheduler_classes[0]
@@ -577,12 +578,12 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()

        assert jnp.sum(jnp.abs(scheduler._get_variance(0, 0, state.alphas_cumprod) - 0.0)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(420, 400, state.alphas_cumprod) - 0.14771)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(980, 960, state.alphas_cumprod) - 0.32460)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(0, 0, state.alphas_cumprod) - 0.0)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(487, 486, state.alphas_cumprod) - 0.00979)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(999, 998, state.alphas_cumprod) - 0.02)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 420, 400) - 0.14771)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 980, 960) - 0.32460)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487, 486) - 0.00979)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999, 998) - 0.02)) < 1e-5

    def test_full_loop_no_noise(self):
        sample = self.full_loop()