
class labels timestep embeddings projection dtype cast (#3137)

This mimics the dtype cast already applied to the standard time embeddings.
Will Berman
2023-04-18 15:05:41 -07:00
committed by GitHub
parent f0c74e9a75
commit fc1883918f
2 changed files with 10 additions and 2 deletions
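To illustrate why the cast matters, here is a minimal, self-contained sketch in plain PyTorch (the names class_embedding and projected_labels are hypothetical stand-ins for self.class_embedding and the output of self.time_proj(class_labels), not the diffusers modules themselves): the sinusoidal timestep projection is always computed in float32, while the class-embedding MLP may hold float16 weights when the UNet runs in half precision, so the projection has to be cast before it is fed in, exactly as is already done for t_emb.

import torch
import torch.nn as nn

# Stand-ins for the real modules: self.class_embedding (running in fp16) and
# the output of self.time_proj(class_labels), which is always fp32.
class_embedding = nn.Linear(320, 1280).half()
projected_labels = torch.randn(2, 320)

print(projected_labels.dtype, class_embedding.weight.dtype)  # torch.float32 torch.float16

# Feeding the fp32 projection into fp16 weights raises a dtype-mismatch
# RuntimeError, so the diff below casts the projection to the sample dtype first.
projected_labels = projected_labels.to(dtype=class_embedding.weight.dtype)

# Half-precision Linear layers may not be supported on CPU in older PyTorch
# builds, so only run the matmul when a CUDA device is available.
if torch.cuda.is_available():
    class_emb = class_embedding.cuda()(projected_labels.cuda())
    print(class_emb.dtype)  # torch.float16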


@@ -659,7 +659,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         t_emb = self.time_proj(timesteps)
 
-        # timesteps does not contain any weights and will always return f32 tensors
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -673,6 +673,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
 
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
 
             if self.config.class_embeddings_concat:


@@ -756,7 +756,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         t_emb = self.time_proj(timesteps)
 
-        # timesteps does not contain any weights and will always return f32 tensors
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -770,6 +770,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
 
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
 
             if self.config.class_embeddings_concat:
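For completeness, the fixed path can also be exercised in isolation with diffusers' own embedding modules. This is a hedged sketch, not the model's forward pass itself: it assumes Timesteps and TimestepEmbedding are importable from diffusers.models.embeddings (as in diffusers releases around this commit) and that the class-embedding MLP is running in fp16.

import torch
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

time_proj = Timesteps(num_channels=320, flip_sin_to_cos=True, downscale_freq_shift=0)
class_embedding = TimestepEmbedding(in_channels=320, time_embed_dim=1280)

sample_dtype = torch.float16          # dtype of the UNet's input sample
class_labels = torch.tensor([7, 42])  # "timestep-style" integer class labels

class_labels = time_proj(class_labels)              # `Timesteps` always returns fp32
print(class_labels.dtype)                           # torch.float32
class_labels = class_labels.to(dtype=sample_dtype)  # the cast added by this commit

# The embedding output is then cast to the module dtype, as in the diff above.
# (Run the fp16 MLP on CUDA; half-precision Linear may be unsupported on CPU.)
if torch.cuda.is_available():
    class_embedding = class_embedding.to(device="cuda", dtype=torch.float16)
    class_emb = class_embedding(class_labels.to("cuda"))
    print(class_emb.dtype, class_emb.shape)          # torch.float16 torch.Size([2, 1280])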