From c8eaea53bfe5ce3421f4fbe0ae62ee12c3f69804 Mon Sep 17 00:00:00 2001
From: Will Berman
Date: Tue, 18 Apr 2023 15:05:41 -0700
Subject: [PATCH] class labels timestep embeddings projection dtype cast (#3137)

This mimics the dtype cast for the standard time embeddings
---
 src/diffusers/models/unet_2d_condition.py               | 6 +++++-
 .../pipelines/versatile_diffusion/modeling_text_unet.py | 6 +++++-
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 29de8734d4..b4997a2576 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -659,7 +659,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
         t_emb = self.time_proj(timesteps)
 
-        # timesteps does not contain any weights and will always return f32 tensors
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -673,6 +673,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
 
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
 
             if self.config.class_embeddings_concat:
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index b20f18c485..2a7b80d01d 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -756,7 +756,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
 
         t_emb = self.time_proj(timesteps)
 
-        # timesteps does not contain any weights and will always return f32 tensors
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -770,6 +770,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
 
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
 
             if self.config.class_embeddings_concat:
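
Note for reviewers: the sketch below reproduces, outside of diffusers, the dtype mismatch this cast avoids. It is illustrative only; `sinusoidal_projection` and `class_embedding` are hypothetical stand-ins for `Timesteps` and the fp16 class-embedding layer, and it assumes your PyTorch build supports fp16 matmul on the device it runs on (any CUDA GPU, or CPU on recent releases).

    # Minimal, self-contained sketch of the failure mode the added cast guards against.
    # These are stand-in modules, not the actual diffusers implementations.
    import torch
    import torch.nn as nn

    device = "cuda" if torch.cuda.is_available() else "cpu"


    def sinusoidal_projection(timesteps: torch.Tensor, dim: int = 8) -> torch.Tensor:
        # Stand-in for `Timesteps`: it holds no weights, so its output is always
        # float32 no matter what dtype the surrounding model runs in.
        half = dim // 2
        freqs = torch.exp(-torch.arange(half, dtype=torch.float32, device=timesteps.device) / half)
        args = timesteps.float()[:, None] * freqs[None, :]
        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)


    # The class-embedding layer runs in fp16, like a UNet loaded with torch_dtype=torch.float16.
    class_embedding = nn.Linear(8, 16).to(device=device, dtype=torch.float16)
    class_labels = torch.tensor([3, 7], device=device)

    class_proj = sinusoidal_projection(class_labels)  # float32, regardless of model dtype

    # Without this cast, the fp16 layer receives a float32 input and PyTorch raises a
    # dtype-mismatch error; the patch adds the cast, mirroring the existing t_emb handling.
    class_proj = class_proj.to(dtype=torch.float16)

    class_emb = class_embedding(class_proj)
    print(class_emb.dtype)  # torch.float16

Removing the `class_proj = class_proj.to(dtype=torch.float16)` line reproduces the pre-patch failure, which is the same mismatch the new `class_labels = class_labels.to(dtype=sample.dtype)` cast prevents when `class_embed_type == "timestep"` and the model runs in fp16.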