diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 704f72f9b6..a9143053a6 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -11,49 +11,104 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
 
-# unet.py
-def get_timestep_embedding(timesteps, embedding_dim):
+def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10000):
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models:
-    From Fairseq.
-    Build sinusoidal embeddings.
-    This matches the implementation in tensor2tensor, but differs slightly
-    from the description in Section 3.5 of "Attention Is All You Need".
-    """
-    assert len(timesteps.shape) == 1
-    half_dim = embedding_dim // 2
-    emb = math.log(10000) / (half_dim - 1)
-    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
-    emb = emb.to(device=timesteps.device)
-    emb = timesteps.float()[:, None] * emb[None, :]
-    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-    if embedding_dim % 2 == 1:  # zero pad
-        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
-    return emb
-
-# unet_glide.py
-def timestep_embedding(timesteps, dim, max_period=10000):
-    """
     Create sinusoidal timestep embeddings.
 
     :param timesteps: a 1-D Tensor of N indices, one per batch element.
                       These may be fractional.
-    :param dim: the dimension of the output.
+    :param embedding_dim: the dimension of the output.
+    :param flip_sin_to_cos: if True, return the cosine half first (the GLIDE/LDM
+                            convention) instead of the sine half first (the DDPM convention).
+    :param downscale_freq_shift: subtracted from the frequency divisor `half_dim`;
+                                 1 matches the old DDPM code, 0 the old GLIDE/LDM code.
     :param max_period: controls the minimum frequency of the embeddings.
-    :return: an [N x dim] Tensor of positional embeddings.
+    :return: an [N x embedding_dim] Tensor of positional embeddings.
     """
-    half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-        device=timesteps.device
-    )
-    args = timesteps[:, None].float() * freqs[None]
-    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-    if dim % 2:
-        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-    return embedding
+    assert len(timesteps.shape) == 1
+
+    half_dim = embedding_dim // 2
+    exponent = -math.log(max_period) * torch.arange(half_dim, dtype=torch.float32)
+    emb = torch.exp(exponent / (half_dim - downscale_freq_shift))
+
+    emb = emb.to(device=timesteps.device)
+    emb = timesteps[:, None].float() * emb[None, :]
+
+    # concat sine and cosine embeddings
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+
+    # flip sine and cosine embeddings
+    if flip_sin_to_cos:
+        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+    # zero pad
+    if embedding_dim % 2 == 1:
+        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+    return emb
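+
+
+# NOTE (illustrative, added for clarity): the defaults above are meant to reproduce
+# the DDPM-style `get_timestep_embedding` formerly in unet.py, while
+# `flip_sin_to_cos=True, downscale_freq_shift=0` is meant to reproduce the GLIDE/LDM
+# `timestep_embedding` (cosine half first, frequency divisor `half_dim`), e.g.:
+#
+#     ddpm_emb = get_timestep_embedding(torch.arange(10), 32)
+#     glide_emb = get_timestep_embedding(torch.arange(10), 32, flip_sin_to_cos=True, downscale_freq_shift=0)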
+
+
 # unet_grad_tts.py
 class SinusoidalPosEmb(torch.nn.Module):
@@ -70,26 +125,6 @@ class SinusoidalPosEmb(torch.nn.Module):
         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
         return emb
 
-# unet_ldm.py
-def timestep_embedding(timesteps, dim, max_period=10000):
-    """
-    Create sinusoidal timestep embeddings.
-
-    :param timesteps: a 1-D Tensor of N indices, one per batch element.
-                      These may be fractional.
-    :param dim: the dimension of the output.
-    :param max_period: controls the minimum frequency of the embeddings.
-    :return: an [N x dim] Tensor of positional embeddings.
-    """
-    half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-        device=timesteps.device
-    )
-    args = timesteps[:, None].float() * freqs[None]
-    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-    if dim % 2:
-        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-    return embedding
 
 # unet_rl.py
 class SinusoidalPosEmb(nn.Module):
@@ -106,22 +141,6 @@ class SinusoidalPosEmb(nn.Module):
         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
         return emb
 
-# unet_sde_score_estimation.py
-def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
-    assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
-    half_dim = embedding_dim // 2
-    # magic number 10000 is from transformers
-    emb = math.log(max_positions) / (half_dim - 1)
-    # emb = math.log(2.) / (half_dim - 1)
-    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
-    # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
-    # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
-    emb = timesteps.float()[:, None] * emb[None, :]
-    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-    if embedding_dim % 2 == 1:  # zero pad
-        emb = F.pad(emb, (0, 1), mode="constant")
-    assert emb.shape == (timesteps.shape[0], embedding_dim)
-    return emb
 
 # unet_sde_score_estimation.py
 class GaussianFourierProjection(nn.Module):
diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py
index a4e1e22df8..7d5eebfd3d 100644
--- a/src/diffusers/models/unet.py
+++ b/src/diffusers/models/unet.py
@@ -30,27 +30,28 @@ from tqdm import tqdm
 
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .embeddings import get_timestep_embedding
 
 
-def get_timestep_embedding(timesteps, embedding_dim):
-    """
-    This matches the implementation in Denoising Diffusion Probabilistic Models:
-    From Fairseq.
-    Build sinusoidal embeddings.
-    This matches the implementation in tensor2tensor, but differs slightly
-    from the description in Section 3.5 of "Attention Is All You Need".
-    """
-    assert len(timesteps.shape) == 1
-
-    half_dim = embedding_dim // 2
-    emb = math.log(10000) / (half_dim - 1)
-    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
-    emb = emb.to(device=timesteps.device)
-    emb = timesteps.float()[:, None] * emb[None, :]
-    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-    if embedding_dim % 2 == 1:  # zero pad
-        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
-    return emb
 
 
 def nonlinearity(x):
diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py
index 648ff9c34a..0e04537766 100644
--- a/src/diffusers/models/unet_glide.py
+++ b/src/diffusers/models/unet_glide.py
@@ -7,6 +7,7 @@ import torch.nn.functional as F
 
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .embeddings import get_timestep_embedding
 
 
 def convert_module_to_f16(l):
@@ -86,25 +87,25 @@ def normalization(channels, swish=0.0):
     return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
 
 
-def timestep_embedding(timesteps, dim, max_period=10000):
-    """
-    Create sinusoidal timestep embeddings.
-
-    :param timesteps: a 1-D Tensor of N indices, one per batch element.
-                      These may be fractional.
-    :param dim: the dimension of the output.
-    :param max_period: controls the minimum frequency of the embeddings.
-    :return: an [N x dim] Tensor of positional embeddings.
-    """
-    half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-        device=timesteps.device
-    )
-    args = timesteps[:, None].float() * freqs[None]
-    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-    if dim % 2:
-        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-    return embedding
 
 
 def zero_module(module):
@@ -627,7 +628,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
         """
         hs = []
-        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+        t_emb = get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        emb = self.time_embed(t_emb)
 
         h = x.type(self.dtype)
         for module in self.input_blocks:
@@ -714,7 +715,7 @@ class GlideTextToImageUNetModel(GlideUNetModel):
 
     def forward(self, x, timesteps, transformer_out=None):
         hs = []
-        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+        t_emb = get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        emb = self.time_embed(t_emb)
 
         # project the last token
         transformer_proj = self.transformer_proj(transformer_out[:, -1])
@@ -806,7 +807,7 @@ class GlideSuperResUNetModel(GlideUNetModel):
         x = torch.cat([x, upsampled], dim=1)
 
         hs = []
-        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+        t_emb = get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        emb = self.time_embed(t_emb)
 
         h = x
         for module in self.input_blocks:
diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py
index cca3231341..cfc200bf6a 100644
--- a/src/diffusers/models/unet_ldm.py
+++ b/src/diffusers/models/unet_ldm.py
@@ -16,6 +16,7 @@ except:
 
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .embeddings import get_timestep_embedding
 
 
 def exists(val):
@@ -316,34 +317,25 @@ def normalization(channels, swish=0.0):
     return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
 
 
-def timestep_embedding(timesteps, dim, max_period=10000):
-    """
-    Create sinusoidal timestep embeddings.
-
-    :param timesteps: a 1-D Tensor of N indices, one per batch element.
-                      These may be fractional.
-    :param dim: the dimension of the output.
-    :param max_period: controls the minimum frequency of the embeddings.
-    :return: an [N x dim] Tensor of positional embeddings.
-    """
-    half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-        device=timesteps.device
-    )
-    args = timesteps[:, None].float() * freqs[None]
-    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-    if dim % 2:
-        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-    return embedding
 
 
 def zero_module(module):
     """
     Zero out the parameters of a module and return it.
     """
     for p in module.parameters():
         p.detach().zero_()
     return module
 
 
 ## go
@@ -1026,7 +1018,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
         hs = []
         if not torch.is_tensor(timesteps):
             timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
-        t_emb = timestep_embedding(timesteps, self.model_channels)
+        t_emb = get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
         emb = self.time_embed(t_emb)
 
         if self.num_classes is not None:
@@ -1240,7 +1232,7 @@ class EncoderUNetModel(nn.Module):
         :param timesteps: a 1-D batch of timesteps.
         :return: an [N x K] Tensor of outputs.
         """
-        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+        t_emb = get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        emb = self.time_embed(t_emb)
 
         results = []
         h = x.type(self.dtype)
diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py
index 299f96c9cd..7d00eb2174 100644
--- a/src/diffusers/models/unet_sde_score_estimation.py
+++ b/src/diffusers/models/unet_sde_score_estimation.py
@@ -26,6 +26,7 @@ import torch.nn.functional as F
 
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .embeddings import get_timestep_embedding
 
 
 def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
@@ -381,21 +382,21 @@ def get_act(nonlinearity):
     raise NotImplementedError("activation function does not exist!")
 
 
-def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
-    assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
-    half_dim = embedding_dim // 2
-    # magic number 10000 is from transformers
-    emb = math.log(max_positions) / (half_dim - 1)
-    # emb = math.log(2.) / (half_dim - 1)
-    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
-    # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
-    # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
-    emb = timesteps.float()[:, None] * emb[None, :]
-    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-    if embedding_dim % 2 == 1:  # zero pad
-        emb = F.pad(emb, (0, 1), mode="constant")
-    assert emb.shape == (timesteps.shape[0], embedding_dim)
-    return emb
 
 
 def default_init(scale=1.0):
diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py
index db4ed6eb02..0b50e7bc86 100755
--- a/tests/test_layers_utils.py
+++ b/tests/test_layers_utils.py
@@ -21,718 +21,24 @@ import unittest
 
 import numpy as np
 import torch
 
-from diffusers import (
-    BDDMPipeline,
-    DDIMPipeline,
-    DDIMScheduler,
-    DDPMPipeline,
-    DDPMScheduler,
-    GlidePipeline,
-    GlideSuperResUNetModel,
-    GlideTextToImageUNetModel,
-    GradTTSPipeline,
-    GradTTSScheduler,
-    LatentDiffusionPipeline,
-    PNDMPipeline,
-    PNDMScheduler,
-    UNetGradTTSModel,
-    UNetLDMModel,
-    UNetModel,
-)
-from diffusers.configuration_utils import ConfigMixin
-from diffusers.pipeline_utils import DiffusionPipeline
-from diffusers.pipelines.pipeline_bddm import DiffWave
+from diffusers.models.embeddings import get_timestep_embedding
 from diffusers.testing_utils import floats_tensor, slow, torch_device
 
 
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
+class EmbeddingsTests(unittest.TestCase):
+    def test_timestep_embeddings(self):
+        embedding_dim = 16
+        timesteps = torch.arange(10)
+
+        t1 = get_timestep_embedding(timesteps, embedding_dim)
+        self.assertEqual(t1.shape, (10, embedding_dim))
+
+        # the sine half comes first by default; flip_sin_to_cos swaps the halves
+        half_dim = embedding_dim // 2
+        t2 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=True)
+        self.assertTrue(torch.allclose(t1[:, :half_dim], t2[:, half_dim:]))
+        self.assertTrue(torch.allclose(t1[:, half_dim:], t2[:, :half_dim]))
+
+        # downscale_freq_shift only rescales the frequencies, not the shape
+        t3 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.assertEqual(t3.shape, (10, embedding_dim))
+
+
-class ConfigTester(unittest.TestCase):
-    def test_load_not_from_mixin(self):
-        with self.assertRaises(ValueError):
-            ConfigMixin.from_config("dummy_path")
-
-    def test_save_load(self):
-        class SampleObject(ConfigMixin):
-            config_name = "config.json"
-
-            def __init__(
-                self,
-                a=2,
-                b=5,
-                c=(2, 5),
-                d="for diffusion",
-                e=[1, 3],
-            ):
-                self.register_to_config(a=a, b=b, c=c, d=d, e=e)
-
-        obj = SampleObject()
-        config = obj.config
-
-        assert config["a"] == 2
-        assert config["b"] == 5
-        assert config["c"] == (2, 5)
-        assert config["d"] == "for diffusion"
-        assert config["e"] == [1, 3]
-
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            obj.save_config(tmpdirname)
-            new_obj = SampleObject.from_config(tmpdirname)
-            new_config = new_obj.config
-
-            # unfreeze configs
-            config = dict(config)
-            new_config = dict(new_config)
-
-            assert config.pop("c") == (2, 5)  # instantiated as tuple
-            assert new_config.pop("c") == [2, 5]  # saved & loaded as list because of json
-            assert config == new_config
-
-
-class ModelTesterMixin:
-    def test_from_pretrained_save_pretrained(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        model = self.model_class(**init_dict)
-        
model.to(torch_device) - model.eval() - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - new_model = self.model_class.from_pretrained(tmpdirname) - new_model.to(torch_device) - - with torch.no_grad(): - image = model(**inputs_dict) - new_image = new_model(**inputs_dict) - - max_diff = (image - new_image).abs().sum().item() - self.assertLessEqual(max_diff, 1e-5, "Models give different forward passes") - - def test_determinism(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - with torch.no_grad(): - first = model(**inputs_dict) - second = model(**inputs_dict) - - out_1 = first.cpu().numpy() - out_2 = second.cpu().numpy() - out_1 = out_1[~np.isnan(out_1)] - out_2 = out_2[~np.isnan(out_2)] - max_diff = np.amax(np.abs(out_1 - out_2)) - self.assertLessEqual(max_diff, 1e-5) - - def test_output(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - output = model(**inputs_dict) - - self.assertIsNotNone(output) - expected_shape = inputs_dict["x"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - def test_forward_signature(self): - init_dict, _ = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**init_dict) - signature = inspect.signature(model.forward) - # signature.parameters is an OrderedDict => so arg_names order is deterministic - arg_names = [*signature.parameters.keys()] - - expected_arg_names = ["x", "timesteps"] - self.assertListEqual(arg_names[:2], expected_arg_names) - - def test_model_from_config(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - # test if the model can be loaded from the config - # and has all the expected shape - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_config(tmpdirname) - new_model = self.model_class.from_config(tmpdirname) - new_model.to(torch_device) - new_model.eval() - - # check if all paramters shape are the same - for param_name in model.state_dict().keys(): - param_1 = model.state_dict()[param_name] - param_2 = new_model.state_dict()[param_name] - self.assertEqual(param_1.shape, param_2.shape) - - with torch.no_grad(): - output_1 = model(**inputs_dict) - output_2 = new_model(**inputs_dict) - - self.assertEqual(output_1.shape, output_2.shape) - - def test_training(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**init_dict) - model.to(torch_device) - model.train() - output = model(**inputs_dict) - noise = torch.randn((inputs_dict["x"].shape[0],) + self.get_output_shape).to(torch_device) - loss = torch.nn.functional.mse_loss(output, noise) - loss.backward() - - -class UnetModelTests(ModelTesterMixin, unittest.TestCase): - model_class = UNetModel - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 3 - sizes = (32, 32) - - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - - return {"x": noise, "timesteps": time_step} - - @property - def get_input_shape(self): - return (3, 32, 32) - - @property - def get_output_shape(self): - return (3, 32, 32) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - 
"ch": 32, - "ch_mult": (1, 2), - "num_res_blocks": 2, - "attn_resolutions": (16,), - "resolution": 32, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_from_pretrained_hub(self): - model, loading_info = UNetModel.from_pretrained("fusing/ddpm_dummy", output_loading_info=True) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) - - model.to(torch_device) - image = model(**self.dummy_input) - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained(self): - model = UNetModel.from_pretrained("fusing/ddpm_dummy") - model.eval() - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - noise = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution) - time_step = torch.tensor([10]) - - with torch.no_grad(): - output = model(noise, time_step) - - output_slice = output[0, -1, -3:, -3:].flatten() - # fmt: off - expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053]) - # fmt: on - self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) - - -class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase): - model_class = GlideSuperResUNetModel - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 6 - sizes = (32, 32) - low_res_size = (4, 4) - - noise = torch.randn((batch_size, num_channels // 2) + sizes).to(torch_device) - low_res = torch.randn((batch_size, 3) + low_res_size).to(torch_device) - time_step = torch.tensor([10] * noise.shape[0], device=torch_device) - - return {"x": noise, "timesteps": time_step, "low_res": low_res} - - @property - def get_input_shape(self): - return (3, 32, 32) - - @property - def get_output_shape(self): - return (6, 32, 32) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "attention_resolutions": (2,), - "channel_mult": (1, 2), - "in_channels": 6, - "out_channels": 6, - "model_channels": 32, - "num_head_channels": 8, - "num_heads_upsample": 1, - "num_res_blocks": 2, - "resblock_updown": True, - "resolution": 32, - "use_scale_shift_norm": True, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_output(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - output = model(**inputs_dict) - - output, _ = torch.split(output, 3, dim=1) - - self.assertIsNotNone(output) - expected_shape = inputs_dict["x"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - def test_from_pretrained_hub(self): - model, loading_info = GlideSuperResUNetModel.from_pretrained( - "fusing/glide-super-res-dummy", output_loading_info=True - ) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) - - model.to(torch_device) - image = model(**self.dummy_input) - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained(self): - model = GlideSuperResUNetModel.from_pretrained("fusing/glide-super-res-dummy") - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - noise = torch.randn(1, 3, 64, 64) - low_res = torch.randn(1, 3, 4, 4) - time_step = torch.tensor([42] * noise.shape[0]) - - with torch.no_grad(): - output = model(noise, time_step, low_res) - - output, _ = torch.split(output, 3, dim=1) - output_slice = 
output[0, -1, -3:, -3:].flatten() - # fmt: off - expected_output_slice = torch.tensor([-22.8782, -23.2652, -15.3966, -22.8034, -23.3159, -15.5640, -15.3970, -15.4614, - 10.4370]) - # fmt: on - self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) - - -class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase): - model_class = GlideTextToImageUNetModel - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 3 - sizes = (32, 32) - transformer_dim = 32 - seq_len = 16 - - noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device) - emb = torch.randn((batch_size, seq_len, transformer_dim)).to(torch_device) - time_step = torch.tensor([10] * noise.shape[0], device=torch_device) - - return {"x": noise, "timesteps": time_step, "transformer_out": emb} - - @property - def get_input_shape(self): - return (3, 32, 32) - - @property - def get_output_shape(self): - return (6, 32, 32) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "attention_resolutions": (2,), - "channel_mult": (1, 2), - "in_channels": 3, - "out_channels": 6, - "model_channels": 32, - "num_head_channels": 8, - "num_heads_upsample": 1, - "num_res_blocks": 2, - "resblock_updown": True, - "resolution": 32, - "use_scale_shift_norm": True, - "transformer_dim": 32, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_output(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - output = model(**inputs_dict) - - output, _ = torch.split(output, 3, dim=1) - - self.assertIsNotNone(output) - expected_shape = inputs_dict["x"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - def test_from_pretrained_hub(self): - model, loading_info = GlideTextToImageUNetModel.from_pretrained( - "fusing/unet-glide-text2im-dummy", output_loading_info=True - ) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) - - model.to(torch_device) - image = model(**self.dummy_input) - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained(self): - model = GlideTextToImageUNetModel.from_pretrained("fusing/unet-glide-text2im-dummy") - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - noise = torch.randn((1, model.config.in_channels, model.config.resolution, model.config.resolution)).to( - torch_device - ) - emb = torch.randn((1, 16, model.config.transformer_dim)).to(torch_device) - time_step = torch.tensor([10] * noise.shape[0], device=torch_device) - - with torch.no_grad(): - output = model(noise, time_step, emb) - - output, _ = torch.split(output, 3, dim=1) - output_slice = output[0, -1, -3:, -3:].flatten() - # fmt: off - expected_output_slice = torch.tensor([2.7766, -10.3558, -14.9149, -0.9376, -14.9175, -17.7679, -5.5565, -12.9521, -12.9845]) - # fmt: on - self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) - - -class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): - model_class = UNetLDMModel - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 4 - sizes = (32, 32) - - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - - return {"x": noise, "timesteps": time_step} - - @property - def get_input_shape(self): - return (4, 32, 32) - - 
@property - def get_output_shape(self): - return (4, 32, 32) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "image_size": 32, - "in_channels": 4, - "out_channels": 4, - "model_channels": 32, - "num_res_blocks": 2, - "attention_resolutions": (16,), - "channel_mult": (1, 2), - "num_heads": 2, - "conv_resample": True, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_from_pretrained_hub(self): - model, loading_info = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy", output_loading_info=True) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) - - model.to(torch_device) - image = model(**self.dummy_input) - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained(self): - model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy") - model.eval() - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size) - time_step = torch.tensor([10] * noise.shape[0]) - - with torch.no_grad(): - output = model(noise, time_step) - - output_slice = output[0, -1, -3:, -3:].flatten() - # fmt: off - expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800]) - # fmt: on - - self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) - - -class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase): - model_class = UNetGradTTSModel - - @property - def dummy_input(self): - batch_size = 4 - num_features = 32 - seq_len = 16 - - noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) - condition = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) - mask = floats_tensor((batch_size, 1, seq_len)).to(torch_device) - time_step = torch.tensor([10] * batch_size).to(torch_device) - - return {"x": noise, "timesteps": time_step, "mu": condition, "mask": mask} - - @property - def get_input_shape(self): - return (4, 32, 16) - - @property - def get_output_shape(self): - return (4, 32, 16) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "dim": 64, - "groups": 4, - "dim_mults": (1, 2), - "n_feats": 32, - "pe_scale": 1000, - "n_spks": 1, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_from_pretrained_hub(self): - model, loading_info = UNetGradTTSModel.from_pretrained("fusing/unet-grad-tts-dummy", output_loading_info=True) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) - - model.to(torch_device) - image = model(**self.dummy_input) - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained(self): - model = UNetGradTTSModel.from_pretrained("fusing/unet-grad-tts-dummy") - model.eval() - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - num_features = model.config.n_feats - seq_len = 16 - noise = torch.randn((1, num_features, seq_len)) - condition = torch.randn((1, num_features, seq_len)) - mask = torch.randn((1, 1, seq_len)) - time_step = torch.tensor([10]) - - with torch.no_grad(): - output = model(noise, time_step, condition, mask) - - output_slice = output[0, -3:, -3:].flatten() - # fmt: off - expected_output_slice = torch.tensor([-0.0690, -0.0531, 0.0633, -0.0660, -0.0541, 0.0650, -0.0656, -0.0555, 0.0617]) - # fmt: on - - self.assertTrue(torch.allclose(output_slice, 
expected_output_slice, atol=1e-3)) - - -class PipelineTesterMixin(unittest.TestCase): - def test_from_pretrained_save_pretrained(self): - # 1. Load models - model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32) - schedular = DDPMScheduler(timesteps=10) - - ddpm = DDPMPipeline(model, schedular) - - with tempfile.TemporaryDirectory() as tmpdirname: - ddpm.save_pretrained(tmpdirname) - new_ddpm = DDPMPipeline.from_pretrained(tmpdirname) - - generator = torch.manual_seed(0) - - image = ddpm(generator=generator) - generator = generator.manual_seed(0) - new_image = new_ddpm(generator=generator) - - assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass" - - @slow - def test_from_pretrained_hub(self): - model_path = "fusing/ddpm-cifar10" - - ddpm = DDPMPipeline.from_pretrained(model_path) - ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path) - - ddpm.noise_scheduler.num_timesteps = 10 - ddpm_from_hub.noise_scheduler.num_timesteps = 10 - - generator = torch.manual_seed(0) - - image = ddpm(generator=generator) - generator = generator.manual_seed(0) - new_image = ddpm_from_hub(generator=generator) - - assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass" - - @slow - def test_ddpm_cifar10(self): - generator = torch.manual_seed(0) - model_id = "fusing/ddpm-cifar10" - - unet = UNetModel.from_pretrained(model_id) - noise_scheduler = DDPMScheduler.from_config(model_id) - noise_scheduler = noise_scheduler.set_format("pt") - - ddpm = DDPMPipeline(unet=unet, noise_scheduler=noise_scheduler) - image = ddpm(generator=generator) - - image_slice = image[0, -1, -3:, -3:].cpu() - - assert image.shape == (1, 3, 32, 32) - expected_slice = torch.tensor([0.2250, 0.3375, 0.2360, 0.0930, 0.3440, 0.3156, 0.1937, 0.3585, 0.1761]) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 - - @slow - def test_ddim_cifar10(self): - generator = torch.manual_seed(0) - model_id = "fusing/ddpm-cifar10" - - unet = UNetModel.from_pretrained(model_id) - noise_scheduler = DDIMScheduler(tensor_format="pt") - - ddim = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler) - image = ddim(generator=generator, eta=0.0) - - image_slice = image[0, -1, -3:, -3:].cpu() - - assert image.shape == (1, 3, 32, 32) - expected_slice = torch.tensor( - [-0.7383, -0.7385, -0.7298, -0.7364, -0.7414, -0.7239, -0.6737, -0.6813, -0.7068] - ) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 - - @slow - def test_pndm_cifar10(self): - generator = torch.manual_seed(0) - model_id = "fusing/ddpm-cifar10" - - unet = UNetModel.from_pretrained(model_id) - noise_scheduler = PNDMScheduler(tensor_format="pt") - - pndm = PNDMPipeline(unet=unet, noise_scheduler=noise_scheduler) - image = pndm(generator=generator) - - image_slice = image[0, -1, -3:, -3:].cpu() - - assert image.shape == (1, 3, 32, 32) - expected_slice = torch.tensor( - [-0.7888, -0.7870, -0.7759, -0.7823, -0.8014, -0.7608, -0.6818, -0.7130, -0.7471] - ) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 - - @slow - def test_ldm_text2img(self): - model_id = "fusing/latent-diffusion-text2im-large" - ldm = LatentDiffusionPipeline.from_pretrained(model_id) - - prompt = "A painting of a squirrel eating a burger" - generator = torch.manual_seed(0) - image = ldm([prompt], generator=generator, num_inference_steps=20) - - image_slice = image[0, -1, -3:, -3:].cpu() - - assert image.shape == (1, 3, 256, 256) - expected_slice = 
torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458]) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 - - @slow - def test_glide_text2img(self): - model_id = "fusing/glide-base" - glide = GlidePipeline.from_pretrained(model_id) - - prompt = "a pencil sketch of a corgi" - generator = torch.manual_seed(0) - image = glide(prompt, generator=generator, num_inference_steps_upscale=20) - - image_slice = image[0, :3, :3, -1].cpu() - - assert image.shape == (1, 256, 256, 3) - expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784]) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 - - @slow - def test_grad_tts(self): - model_id = "fusing/grad-tts-libri-tts" - grad_tts = GradTTSPipeline.from_pretrained(model_id) - noise_scheduler = GradTTSScheduler() - grad_tts.noise_scheduler = noise_scheduler - - text = "Hello world, I missed you so much." - generator = torch.manual_seed(0) - - # generate mel spectograms using text - mel_spec = grad_tts(text, generator=generator) - - assert mel_spec.shape == (1, 80, 143) - expected_slice = torch.tensor( - [-6.7584, -6.8347, -6.3293, -6.6437, -6.7233, -6.4684, -6.1187, -6.3172, -6.6890] - ) - assert (mel_spec[0, :3, :3].cpu().flatten() - expected_slice).abs().max() < 1e-2 - - def test_module_from_pipeline(self): - model = DiffWave(num_res_layers=4) - noise_scheduler = DDPMScheduler(timesteps=12) - - bddm = BDDMPipeline(model, noise_scheduler) - - # check if the library name for the diffwave moduel is set to pipeline module - self.assertTrue(bddm.config["diffwave"][0] == "pipeline_bddm") - - # check if we can save and load the pipeline - with tempfile.TemporaryDirectory() as tmpdirname: - bddm.save_pretrained(tmpdirname) - _ = BDDMPipeline.from_pretrained(tmpdirname) - # check if the same works using the DifusionPipeline class - _ = DiffusionPipeline.from_pretrained(tmpdirname)
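
Usage sketch (illustrative, not part of the patch): after this refactor every UNet obtains its timestep embedding from the one shared helper, with two presets. Assuming only the new signature in src/diffusers/models/embeddings.py, the snippet below contrasts the DDPM-style defaults with the GLIDE/LDM-style call used at the call sites above.

    import torch
    from diffusers.models.embeddings import get_timestep_embedding

    timesteps = torch.tensor([0, 500, 999])

    # DDPM convention: sine half first, frequency divisor half_dim - 1 (the defaults)
    ddpm_emb = get_timestep_embedding(timesteps, embedding_dim=128)

    # GLIDE/LDM convention: cosine half first, frequency divisor half_dim
    glide_emb = get_timestep_embedding(
        timesteps, embedding_dim=128, flip_sin_to_cos=True, downscale_freq_shift=0
    )

    print(ddpm_emb.shape, glide_emb.shape)  # torch.Size([3, 128]) torch.Size([3, 128])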