diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py new file mode 100644 index 0000000000..333aeb85d5 --- /dev/null +++ b/src/diffusers/models/embeddings.py @@ -0,0 +1,56 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# unet.py +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + +# unet_glide.py +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=timesteps.device + ) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index a2bdd951e4..81719f088b 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -198,6 +198,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): if not isinstance(spk, type(None)): s = self.spk_mlp(spk) + t = self.time_pos_emb(timesteps, scale=self.pe_scale) t = self.mlp(t) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 6c5c115f19..4556753006 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -113,7 +113,7 @@ class ModelTesterMixin: new_image = new_model(**inputs_dict) max_diff = (image - new_image).abs().sum().item() - self.assertLessEqual(max_diff, 1e-5, "Models give different forward passes") + self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes") def test_determinism(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -431,11 +431,12 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase): emb = torch.randn((1, 16, model.config.transformer_dim)).to(torch_device) time_step = torch.tensor([10] * noise.shape[0], device=torch_device) + model.to(torch_device) with torch.no_grad(): output = model(noise, time_step, emb) output, _ = torch.split(output, 3, dim=1) - output_slice = output[0, -1, -3:, -3:].flatten() + output_slice = output[0, -1, -3:, -3:].cpu().flatten() # fmt: off expected_output_slice = torch.tensor([2.7766, -10.3558, -14.9149, -0.9376, -14.9175, -17.7679, -5.5565, -12.9521, -12.9845]) # fmt: on