mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Flax] added broadcast_to_shape_from_left helper and Scheduler tests (#864)
* added broadcast_to_shape_from_left helper * initial tests * fixed pndm tests * shape required for pndm * added require_flax * fix style * fix more imports Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -34,7 +34,7 @@ if is_flax_available():
|
||||
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
|
||||
from .scheduling_pndm_flax import FlaxPNDMScheduler
|
||||
from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
|
||||
else:
|
||||
from ..utils.dummy_flax_objects import * # noqa F403
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ import flax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
|
||||
@@ -173,7 +173,9 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
|
||||
return variance
|
||||
|
||||
def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDIMSchedulerState:
|
||||
def set_timesteps(
|
||||
self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = ()
|
||||
) -> DDIMSchedulerState:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
@@ -211,9 +213,6 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`jnp.ndarray`):
|
||||
current instance of sample being created by diffusion process.
|
||||
key (`random.KeyArray`): a PRNG key.
|
||||
eta (`float`): weight of noise for added noise in diffusion step.
|
||||
use_clipped_model_output (`bool`): TODO
|
||||
return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class
|
||||
|
||||
Returns:
|
||||
@@ -279,13 +278,11 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
) -> jnp.ndarray:
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod[:, None]
|
||||
sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.0
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[:, None]
|
||||
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
|
||||
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
@@ -23,7 +23,7 @@ import jax.numpy as jnp
|
||||
from jax import random
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
|
||||
@@ -101,6 +101,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
|
||||
"""
|
||||
|
||||
@property
|
||||
def has_state(self):
|
||||
return True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
@@ -129,11 +133,12 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
|
||||
self.one = jnp.array(1.0)
|
||||
|
||||
self.state = DDPMSchedulerState.create(num_train_timesteps=num_train_timesteps)
|
||||
def create_state(self):
|
||||
return DDPMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
|
||||
|
||||
self.variance_type = variance_type
|
||||
|
||||
def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDPMSchedulerState:
|
||||
def set_timesteps(
|
||||
self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple = ()
|
||||
) -> DDPMSchedulerState:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
@@ -214,7 +219,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
t = timestep
|
||||
|
||||
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
|
||||
if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
|
||||
model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1)
|
||||
else:
|
||||
predicted_variance = None
|
||||
@@ -267,13 +272,11 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
) -> jnp.ndarray:
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod[..., None]
|
||||
sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
|
||||
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
|
||||
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
@@ -87,6 +87,10 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
A reasonable range is [0.2, 80].
|
||||
"""
|
||||
|
||||
@property
|
||||
def has_state(self):
|
||||
return True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
@@ -97,10 +101,13 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
s_min: float = 0.05,
|
||||
s_max: float = 50,
|
||||
):
|
||||
self.state = KarrasVeSchedulerState.create()
|
||||
pass
|
||||
|
||||
def create_state(self):
|
||||
return KarrasVeSchedulerState.create()
|
||||
|
||||
def set_timesteps(
|
||||
self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple
|
||||
self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple = ()
|
||||
) -> KarrasVeSchedulerState:
|
||||
"""
|
||||
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
@@ -20,7 +20,7 @@ import jax.numpy as jnp
|
||||
from scipy import integrate
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
@@ -63,6 +63,10 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||
"""
|
||||
|
||||
@property
|
||||
def has_state(self):
|
||||
return True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
@@ -85,8 +89,10 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
|
||||
|
||||
def create_state(self):
|
||||
self.state = LMSDiscreteSchedulerState.create(
|
||||
num_train_timesteps=num_train_timesteps, sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
|
||||
num_train_timesteps=self.config.num_train_timesteps,
|
||||
sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5,
|
||||
)
|
||||
|
||||
def get_lms_coefficient(self, state, order, t, current_order):
|
||||
@@ -112,7 +118,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
return integrated_coeff
|
||||
|
||||
def set_timesteps(
|
||||
self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple
|
||||
self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
|
||||
) -> LMSDiscreteSchedulerState:
|
||||
"""
|
||||
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
@@ -199,8 +205,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
timesteps: jnp.ndarray,
|
||||
) -> jnp.ndarray:
|
||||
sigma = state.sigmas[timesteps].flatten()
|
||||
while len(sigma.shape) < len(noise.shape):
|
||||
sigma = sigma[..., None]
|
||||
sigma = broadcast_to_shape_from_left(sigma, noise.shape)
|
||||
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
|
||||
@@ -168,6 +168,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
the `FlaxPNDMScheduler` state data class instance.
|
||||
num_inference_steps (`int`):
|
||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
shape (`Tuple`):
|
||||
the shape of the samples to be generated.
|
||||
"""
|
||||
offset = self.config.steps_offset
|
||||
|
||||
@@ -509,13 +511,11 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
) -> jnp.ndarray:
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod[..., None]
|
||||
sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
|
||||
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
|
||||
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
@@ -22,7 +22,7 @@ import jax.numpy as jnp
|
||||
from jax import random
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
@@ -80,6 +80,10 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
correct_steps (`int`): number of correction steps performed on a produced sample.
|
||||
"""
|
||||
|
||||
@property
|
||||
def has_state(self):
|
||||
return True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
@@ -90,12 +94,20 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
sampling_eps: float = 1e-5,
|
||||
correct_steps: int = 1,
|
||||
):
|
||||
state = ScoreSdeVeSchedulerState.create()
|
||||
pass
|
||||
|
||||
self.state = self.set_sigmas(state, num_train_timesteps, sigma_min, sigma_max, sampling_eps)
|
||||
def create_state(self):
|
||||
state = ScoreSdeVeSchedulerState.create()
|
||||
return self.set_sigmas(
|
||||
state,
|
||||
self.config.num_train_timesteps,
|
||||
self.config.sigma_min,
|
||||
self.config.sigma_max,
|
||||
self.config.sampling_eps,
|
||||
)
|
||||
|
||||
def set_timesteps(
|
||||
self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple, sampling_eps: float = None
|
||||
self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple = (), sampling_eps: float = None
|
||||
) -> ScoreSdeVeSchedulerState:
|
||||
"""
|
||||
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
@@ -193,8 +205,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
# equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
|
||||
# also equation 47 shows the analog from SDE models to ancestral sampling methods
|
||||
diffusion = diffusion.flatten()
|
||||
while len(diffusion.shape) < len(sample.shape):
|
||||
diffusion = diffusion[:, None]
|
||||
diffusion = broadcast_to_shape_from_left(diffusion, sample.shape)
|
||||
drift = drift - diffusion**2 * model_output
|
||||
|
||||
# equation 6: sample noise for the diffusion term of
|
||||
@@ -252,8 +263,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
|
||||
# compute corrected sample: model_output term and noise term
|
||||
step_size = step_size.flatten()
|
||||
while len(step_size.shape) < len(sample.shape):
|
||||
step_size = step_size[:, None]
|
||||
step_size = broadcast_to_shape_from_left(step_size, sample.shape)
|
||||
prev_sample_mean = sample + step_size * model_output
|
||||
prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import jax.numpy as jnp
|
||||
|
||||
@@ -41,3 +42,8 @@ class FlaxSchedulerMixin:
|
||||
"""
|
||||
|
||||
config_name = SCHEDULER_CONFIG_NAME
|
||||
|
||||
|
||||
def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray:
|
||||
assert len(shape) >= x.ndim
|
||||
return jnp.broadcast_to(x.reshape(x.shape + (1,) * (len(shape) - x.ndim)), shape)
|
||||
|
||||
863
tests/test_scheduler_flax.py
Normal file
863
tests/test_scheduler_flax.py
Normal file
@@ -0,0 +1,863 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from diffusers import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler
|
||||
from diffusers.utils import is_flax_available
|
||||
from diffusers.utils.testing_utils import require_flax
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
from jax import random
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxSchedulerCommonTest(unittest.TestCase):
|
||||
scheduler_classes = ()
|
||||
forward_default_kwargs = ()
|
||||
|
||||
@property
|
||||
def dummy_sample(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
height = 8
|
||||
width = 8
|
||||
|
||||
key1, key2 = random.split(random.PRNGKey(0))
|
||||
sample = random.uniform(key1, (batch_size, num_channels, height, width))
|
||||
|
||||
return sample, key2
|
||||
|
||||
@property
|
||||
def dummy_sample_deter(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
height = 8
|
||||
width = 8
|
||||
|
||||
num_elems = batch_size * num_channels * height * width
|
||||
sample = jnp.arange(num_elems)
|
||||
sample = sample.reshape(num_channels, height, width, batch_size)
|
||||
sample = sample / num_elems
|
||||
return jnp.transpose(sample, (3, 0, 1, 2))
|
||||
|
||||
def get_scheduler_config(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def dummy_model(self):
|
||||
def model(sample, t, *args):
|
||||
return sample * t / (t + 1)
|
||||
|
||||
return model
|
||||
|
||||
def check_over_configs(self, time_step=0, **config):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
sample, key = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
scheduler_config = self.get_scheduler_config(**config)
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
state = scheduler.create_state()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
state = scheduler.set_timesteps(state, num_inference_steps)
|
||||
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample
|
||||
new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample
|
||||
|
||||
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def check_over_forward(self, time_step=0, **forward_kwargs):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
kwargs.update(forward_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
sample, key = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
state = scheduler.create_state()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
state = scheduler.set_timesteps(state, num_inference_steps)
|
||||
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample
|
||||
new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample
|
||||
|
||||
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
sample, key = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
state = scheduler.create_state()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
state = scheduler.set_timesteps(state, num_inference_steps)
|
||||
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
output = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample
|
||||
new_output = new_scheduler.step(new_state, residual, 1, sample, key, **kwargs).prev_sample
|
||||
|
||||
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_step_shape(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
state = scheduler.create_state()
|
||||
|
||||
sample, key = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
state = scheduler.set_timesteps(state, num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
output_0 = scheduler.step(state, residual, 0, sample, key, **kwargs).prev_sample
|
||||
output_1 = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample
|
||||
|
||||
self.assertEqual(output_0.shape, sample.shape)
|
||||
self.assertEqual(output_0.shape, output_1.shape)
|
||||
|
||||
def test_scheduler_outputs_equivalence(self):
|
||||
def set_nan_tensor_to_zero(t):
|
||||
return t.at[t != t].set(0)
|
||||
|
||||
def recursive_check(tuple_object, dict_object):
|
||||
if isinstance(tuple_object, (List, Tuple)):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif isinstance(tuple_object, Dict):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif tuple_object is None:
|
||||
return
|
||||
else:
|
||||
self.assertTrue(
|
||||
jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
|
||||
msg=(
|
||||
"Tuple and dict output are not equal. Difference:"
|
||||
f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
|
||||
f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has"
|
||||
f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}."
|
||||
),
|
||||
)
|
||||
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
state = scheduler.create_state()
|
||||
|
||||
sample, key = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
state = scheduler.set_timesteps(state, num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
outputs_dict = scheduler.step(state, residual, 0, sample, key, **kwargs)
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
state = scheduler.set_timesteps(state, num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
outputs_tuple = scheduler.step(state, residual, 0, sample, key, return_dict=False, **kwargs)
|
||||
|
||||
recursive_check(outputs_tuple[0], outputs_dict.prev_sample)
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest):
|
||||
scheduler_classes = (FlaxDDPMScheduler,)
|
||||
|
||||
def get_scheduler_config(self, **kwargs):
|
||||
config = {
|
||||
"num_train_timesteps": 1000,
|
||||
"beta_start": 0.0001,
|
||||
"beta_end": 0.02,
|
||||
"beta_schedule": "linear",
|
||||
"variance_type": "fixed_small",
|
||||
"clip_sample": True,
|
||||
}
|
||||
|
||||
config.update(**kwargs)
|
||||
return config
|
||||
|
||||
def test_timesteps(self):
|
||||
for timesteps in [1, 5, 100, 1000]:
|
||||
self.check_over_configs(num_train_timesteps=timesteps)
|
||||
|
||||
def test_betas(self):
|
||||
for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
|
||||
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
|
||||
|
||||
def test_schedules(self):
|
||||
for schedule in ["linear", "squaredcos_cap_v2"]:
|
||||
self.check_over_configs(beta_schedule=schedule)
|
||||
|
||||
def test_variance_type(self):
|
||||
for variance in ["fixed_small", "fixed_large", "other"]:
|
||||
self.check_over_configs(variance_type=variance)
|
||||
|
||||
def test_clip_sample(self):
|
||||
for clip_sample in [True, False]:
|
||||
self.check_over_configs(clip_sample=clip_sample)
|
||||
|
||||
def test_time_indices(self):
|
||||
for t in [0, 500, 999]:
|
||||
self.check_over_forward(time_step=t)
|
||||
|
||||
def test_variance(self):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
assert jnp.sum(jnp.abs(scheduler._get_variance(0) - 0.0)) < 1e-5
|
||||
assert jnp.sum(jnp.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5
|
||||
assert jnp.sum(jnp.abs(scheduler._get_variance(999) - 0.02)) < 1e-5
|
||||
|
||||
def test_full_loop_no_noise(self):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
state = scheduler.create_state()
|
||||
|
||||
num_trained_timesteps = len(scheduler)
|
||||
|
||||
model = self.dummy_model()
|
||||
sample = self.dummy_sample_deter
|
||||
key1, key2 = random.split(random.PRNGKey(0))
|
||||
|
||||
for t in reversed(range(num_trained_timesteps)):
|
||||
# 1. predict noise residual
|
||||
residual = model(sample, t)
|
||||
|
||||
# 2. predict previous mean of sample x_t-1
|
||||
output = scheduler.step(state, residual, t, sample, key1)
|
||||
pred_prev_sample = output.prev_sample
|
||||
state = output.state
|
||||
key1, key2 = random.split(key2)
|
||||
|
||||
# if t > 0:
|
||||
# noise = self.dummy_sample_deter
|
||||
# variance = scheduler.get_variance(t) ** (0.5) * noise
|
||||
#
|
||||
# sample = pred_prev_sample + variance
|
||||
sample = pred_prev_sample
|
||||
|
||||
result_sum = jnp.sum(jnp.abs(sample))
|
||||
result_mean = jnp.mean(jnp.abs(sample))
|
||||
|
||||
assert abs(result_sum - 255.1113) < 1e-2
|
||||
assert abs(result_mean - 0.332176) < 1e-3
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
|
||||
scheduler_classes = (FlaxDDIMScheduler,)
|
||||
forward_default_kwargs = (("num_inference_steps", 50),)
|
||||
|
||||
def get_scheduler_config(self, **kwargs):
|
||||
config = {
|
||||
"num_train_timesteps": 1000,
|
||||
"beta_start": 0.0001,
|
||||
"beta_end": 0.02,
|
||||
"beta_schedule": "linear",
|
||||
}
|
||||
|
||||
config.update(**kwargs)
|
||||
return config
|
||||
|
||||
def full_loop(self, **config):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config(**config)
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
state = scheduler.create_state()
|
||||
key1, key2 = random.split(random.PRNGKey(0))
|
||||
|
||||
num_inference_steps = 10
|
||||
|
||||
model = self.dummy_model()
|
||||
sample = self.dummy_sample_deter
|
||||
|
||||
state = scheduler.set_timesteps(state, num_inference_steps)
|
||||
|
||||
for t in state.timesteps:
|
||||
residual = model(sample, t)
|
||||
output = scheduler.step(state, residual, t, sample)
|
||||
sample = output.prev_sample
|
||||
state = output.state
|
||||
key1, key2 = random.split(key2)
|
||||
|
||||
return sample
|
||||
|
||||
def check_over_configs(self, time_step=0, **config):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
sample, _ = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
scheduler_config = self.get_scheduler_config(**config)
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
state = scheduler.create_state()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
state = scheduler.set_timesteps(state, num_inference_steps)
|
||||
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample
|
||||
new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample
|
||||
|
||||
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
sample, _ = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
state = scheduler.create_state()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
state = scheduler.set_timesteps(state, num_inference_steps)
|
||||
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample
|
||||
new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample
|
||||
|
||||
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def check_over_forward(self, time_step=0, **forward_kwargs):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
kwargs.update(forward_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
sample, _ = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
state = scheduler.create_state()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
state = scheduler.set_timesteps(state, num_inference_steps)
|
||||
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample
|
||||
new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample
|
||||
|
||||
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_scheduler_outputs_equivalence(self):
|
||||
def set_nan_tensor_to_zero(t):
|
||||
return t.at[t != t].set(0)
|
||||
|
||||
def recursive_check(tuple_object, dict_object):
|
||||
if isinstance(tuple_object, (List, Tuple)):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif isinstance(tuple_object, Dict):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif tuple_object is None:
|
||||
return
|
||||
else:
|
||||
self.assertTrue(
|
||||
jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
|
||||
msg=(
|
||||
"Tuple and dict output are not equal. Difference:"
|
||||
f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
|
||||
f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has"
|
||||
f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}."
|
||||
),
|
||||
)
|
||||
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
state = scheduler.create_state()
|
||||
|
||||
sample, _ = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
state = scheduler.set_timesteps(state, num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs)
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
state = scheduler.set_timesteps(state, num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs)
|
||||
|
||||
recursive_check(outputs_tuple[0], outputs_dict.prev_sample)
|
||||
|
||||
def test_step_shape(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
state = scheduler.create_state()
|
||||
|
||||
sample, _ = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
state = scheduler.set_timesteps(state, num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
output_0 = scheduler.step(state, residual, 0, sample, **kwargs).prev_sample
|
||||
output_1 = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample
|
||||
|
||||
self.assertEqual(output_0.shape, sample.shape)
|
||||
self.assertEqual(output_0.shape, output_1.shape)
|
||||
|
||||
def test_timesteps(self):
|
||||
for timesteps in [100, 500, 1000]:
|
||||
self.check_over_configs(num_train_timesteps=timesteps)
|
||||
|
||||
def test_steps_offset(self):
|
||||
for steps_offset in [0, 1]:
|
||||
self.check_over_configs(steps_offset=steps_offset)
|
||||
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config(steps_offset=1)
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
state = scheduler.create_state()
|
||||
state = scheduler.set_timesteps(state, 5)
|
||||
assert jnp.equal(state.timesteps, jnp.array([801, 601, 401, 201, 1])).all()
|
||||
|
||||
def test_betas(self):
|
||||
for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
|
||||
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
|
||||
|
||||
def test_schedules(self):
|
||||
for schedule in ["linear", "squaredcos_cap_v2"]:
|
||||
self.check_over_configs(beta_schedule=schedule)
|
||||
|
||||
def test_time_indices(self):
|
||||
for t in [1, 10, 49]:
|
||||
self.check_over_forward(time_step=t)
|
||||
|
||||
def test_inference_steps(self):
|
||||
for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
|
||||
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
|
||||
|
||||
def test_variance(self):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
state = scheduler.create_state()
|
||||
|
||||
assert jnp.sum(jnp.abs(scheduler._get_variance(0, 0, state.alphas_cumprod) - 0.0)) < 1e-5
|
||||
assert jnp.sum(jnp.abs(scheduler._get_variance(420, 400, state.alphas_cumprod) - 0.14771)) < 1e-5
|
||||
assert jnp.sum(jnp.abs(scheduler._get_variance(980, 960, state.alphas_cumprod) - 0.32460)) < 1e-5
|
||||
assert jnp.sum(jnp.abs(scheduler._get_variance(0, 0, state.alphas_cumprod) - 0.0)) < 1e-5
|
||||
assert jnp.sum(jnp.abs(scheduler._get_variance(487, 486, state.alphas_cumprod) - 0.00979)) < 1e-5
|
||||
assert jnp.sum(jnp.abs(scheduler._get_variance(999, 998, state.alphas_cumprod) - 0.02)) < 1e-5
|
||||
|
||||
def test_full_loop_no_noise(self):
|
||||
sample = self.full_loop()
|
||||
|
||||
result_sum = jnp.sum(jnp.abs(sample))
|
||||
result_mean = jnp.mean(jnp.abs(sample))
|
||||
|
||||
assert abs(result_sum - 172.0067) < 1e-2
|
||||
assert abs(result_mean - 0.223967) < 1e-3
|
||||
|
||||
def test_full_loop_with_set_alpha_to_one(self):
|
||||
# We specify different beta, so that the first alpha is 0.99
|
||||
sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
|
||||
result_sum = jnp.sum(jnp.abs(sample))
|
||||
result_mean = jnp.mean(jnp.abs(sample))
|
||||
|
||||
assert abs(result_sum - 149.8295) < 1e-2
|
||||
assert abs(result_mean - 0.1951) < 1e-3
|
||||
|
||||
def test_full_loop_with_no_set_alpha_to_one(self):
|
||||
# We specify different beta, so that the first alpha is 0.99
|
||||
sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
|
||||
result_sum = jnp.sum(jnp.abs(sample))
|
||||
result_mean = jnp.mean(jnp.abs(sample))
|
||||
|
||||
assert abs(result_sum - 149.0784) < 1e-2
|
||||
assert abs(result_mean - 0.1941) < 1e-3
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
|
||||
scheduler_classes = (FlaxPNDMScheduler,)
|
||||
forward_default_kwargs = (("num_inference_steps", 50),)
|
||||
|
||||
def get_scheduler_config(self, **kwargs):
|
||||
config = {
|
||||
"num_train_timesteps": 1000,
|
||||
"beta_start": 0.0001,
|
||||
"beta_end": 0.02,
|
||||
"beta_schedule": "linear",
|
||||
}
|
||||
|
||||
config.update(**kwargs)
|
||||
return config
|
||||
|
||||
def check_over_configs(self, time_step=0, **config):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
sample, _ = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config(**config)
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
state = scheduler.create_state()
|
||||
state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
|
||||
# copy over dummy past residuals
|
||||
state = state.replace(ets=dummy_past_residuals[:])
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
|
||||
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
|
||||
# copy over dummy past residuals
|
||||
new_state = new_state.replace(ets=dummy_past_residuals[:])
|
||||
|
||||
(prev_sample, state) = scheduler.step_prk(state, residual, time_step, sample, **kwargs)
|
||||
(new_prev_sample, new_state) = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs)
|
||||
|
||||
assert jnp.sum(jnp.abs(prev_sample - new_prev_sample)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs)
|
||||
new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs)
|
||||
|
||||
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
pass
|
||||
|
||||
def test_scheduler_outputs_equivalence(self):
|
||||
def set_nan_tensor_to_zero(t):
|
||||
return t.at[t != t].set(0)
|
||||
|
||||
def recursive_check(tuple_object, dict_object):
|
||||
if isinstance(tuple_object, (List, Tuple)):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif isinstance(tuple_object, Dict):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif tuple_object is None:
|
||||
return
|
||||
else:
|
||||
self.assertTrue(
|
||||
jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
|
||||
msg=(
|
||||
"Tuple and dict output are not equal. Difference:"
|
||||
f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
|
||||
f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has"
|
||||
f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}."
|
||||
),
|
||||
)
|
||||
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
state = scheduler.create_state()
|
||||
|
||||
sample, _ = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs)
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs)
|
||||
|
||||
recursive_check(outputs_tuple[0], outputs_dict.prev_sample)
|
||||
|
||||
def check_over_forward(self, time_step=0, **forward_kwargs):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
sample, _ = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
state = scheduler.create_state()
|
||||
state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
|
||||
|
||||
# copy over dummy past residuals (must be after setting timesteps)
|
||||
scheduler.ets = dummy_past_residuals[:]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
|
||||
# copy over dummy past residuals
|
||||
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
|
||||
|
||||
# copy over dummy past residual (must be after setting timesteps)
|
||||
new_state.replace(ets=dummy_past_residuals[:])
|
||||
|
||||
output, state = scheduler.step_prk(state, residual, time_step, sample, **kwargs)
|
||||
new_output, new_state = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs)
|
||||
|
||||
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs)
|
||||
new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs)
|
||||
|
||||
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def full_loop(self, **config):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config(**config)
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
state = scheduler.create_state()
|
||||
|
||||
num_inference_steps = 10
|
||||
model = self.dummy_model()
|
||||
sample = self.dummy_sample_deter
|
||||
state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
|
||||
|
||||
for i, t in enumerate(state.prk_timesteps):
|
||||
residual = model(sample, t)
|
||||
sample, state = scheduler.step_prk(state, residual, t, sample)
|
||||
|
||||
for i, t in enumerate(state.plms_timesteps):
|
||||
residual = model(sample, t)
|
||||
sample, state = scheduler.step_plms(state, residual, t, sample)
|
||||
|
||||
return sample
|
||||
|
||||
def test_step_shape(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
state = scheduler.create_state()
|
||||
|
||||
sample, _ = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
# copy over dummy past residuals (must be done after set_timesteps)
|
||||
dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])
|
||||
state = state.replace(ets=dummy_past_residuals[:])
|
||||
|
||||
output_0, state = scheduler.step_prk(state, residual, 0, sample, **kwargs)
|
||||
output_1, state = scheduler.step_prk(state, residual, 1, sample, **kwargs)
|
||||
|
||||
self.assertEqual(output_0.shape, sample.shape)
|
||||
self.assertEqual(output_0.shape, output_1.shape)
|
||||
|
||||
output_0, state = scheduler.step_plms(state, residual, 0, sample, **kwargs)
|
||||
output_1, state = scheduler.step_plms(state, residual, 1, sample, **kwargs)
|
||||
|
||||
self.assertEqual(output_0.shape, sample.shape)
|
||||
self.assertEqual(output_0.shape, output_1.shape)
|
||||
|
||||
def test_timesteps(self):
|
||||
for timesteps in [100, 1000]:
|
||||
self.check_over_configs(num_train_timesteps=timesteps)
|
||||
|
||||
def test_steps_offset(self):
|
||||
for steps_offset in [0, 1]:
|
||||
self.check_over_configs(steps_offset=steps_offset)
|
||||
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config(steps_offset=1)
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
state = scheduler.create_state()
|
||||
state = scheduler.set_timesteps(state, 10, shape=())
|
||||
assert jnp.equal(
|
||||
state.timesteps,
|
||||
jnp.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]),
|
||||
).all()
|
||||
|
||||
def test_betas(self):
|
||||
for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):
|
||||
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
|
||||
|
||||
def test_schedules(self):
|
||||
for schedule in ["linear", "squaredcos_cap_v2"]:
|
||||
self.check_over_configs(beta_schedule=schedule)
|
||||
|
||||
def test_time_indices(self):
|
||||
for t in [1, 5, 10]:
|
||||
self.check_over_forward(time_step=t)
|
||||
|
||||
def test_inference_steps(self):
|
||||
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
|
||||
self.check_over_forward(num_inference_steps=num_inference_steps)
|
||||
|
||||
def test_pow_of_3_inference_steps(self):
|
||||
# earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3
|
||||
num_inference_steps = 27
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
sample, _ = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
state = scheduler.create_state()
|
||||
|
||||
state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
|
||||
|
||||
# before power of 3 fix, would error on first step, so we only need to do two
|
||||
for i, t in enumerate(state.prk_timesteps[:2]):
|
||||
sample, state = scheduler.step_prk(state, residual, t, sample)
|
||||
|
||||
def test_inference_plms_no_past_residuals(self):
|
||||
with self.assertRaises(ValueError):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
state = scheduler.create_state()
|
||||
|
||||
scheduler.step_plms(state, self.dummy_sample, 1, self.dummy_sample).prev_sample
|
||||
|
||||
def test_full_loop_no_noise(self):
|
||||
sample = self.full_loop()
|
||||
result_sum = jnp.sum(jnp.abs(sample))
|
||||
result_mean = jnp.mean(jnp.abs(sample))
|
||||
|
||||
assert abs(result_sum - 198.1318) < 1e-2
|
||||
assert abs(result_mean - 0.2580) < 1e-3
|
||||
|
||||
def test_full_loop_with_set_alpha_to_one(self):
|
||||
# We specify different beta, so that the first alpha is 0.99
|
||||
sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
|
||||
result_sum = jnp.sum(jnp.abs(sample))
|
||||
result_mean = jnp.mean(jnp.abs(sample))
|
||||
|
||||
assert abs(result_sum - 186.9466) < 1e-2
|
||||
assert abs(result_mean - 0.24342) < 1e-3
|
||||
|
||||
def test_full_loop_with_no_set_alpha_to_one(self):
|
||||
# We specify different beta, so that the first alpha is 0.99
|
||||
sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
|
||||
result_sum = jnp.sum(jnp.abs(sample))
|
||||
result_mean = jnp.mean(jnp.abs(sample))
|
||||
|
||||
assert abs(result_sum - 186.9482) < 1e-2
|
||||
assert abs(result_mean - 0.2434) < 1e-3
|
||||
Reference in New Issue
Block a user