mirror of https://github.com/huggingface/diffusers.git
class labels timestep embeddings projection dtype cast (#3137)
This mimics the dtype cast for the standard timestep embeddings.
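For context, a minimal sketch (not part of the commit) of the mismatch the added cast guards against, using the diffusers Timesteps and TimestepEmbedding building blocks directly. The 320 -> 1280 sizes are illustrative, and a device with fp16 kernels (typically a CUDA GPU) is assumed for the half-precision forward pass.

import torch
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

device = "cuda" if torch.cuda.is_available() else "cpu"

# Sinusoidal projection (no weights, always returns float32) and an embedding
# MLP running in fp16; the sizes here are made up for the example.
time_proj = Timesteps(num_channels=320, flip_sin_to_cos=True, downscale_freq_shift=0)
class_embedding = TimestepEmbedding(in_channels=320, time_embed_dim=1280).to(device, torch.float16)

class_labels = torch.tensor([3, 7], device=device)
class_labels = time_proj(class_labels)
print(class_labels.dtype)  # torch.float32, regardless of the model dtype

# The cast this commit adds for class labels, mirroring the existing t_emb cast.
# Without it, the fp16 linear layers inside TimestepEmbedding see a float32
# input and fail with a dtype-mismatch error.
class_labels = class_labels.to(dtype=torch.float16)
class_emb = class_embedding(class_labels)
print(class_emb.dtype, class_emb.shape)  # torch.float16, torch.Size([2, 1280])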
@@ -659,7 +659,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
         t_emb = self.time_proj(timesteps)
 
-        # timesteps does not contain any weights and will always return f32 tensors
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -673,6 +673,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
 
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
 
             if self.config.class_embeddings_concat:
@@ -756,7 +756,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
 
         t_emb = self.time_proj(timesteps)
 
-        # timesteps does not contain any weights and will always return f32 tensors
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -770,6 +770,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
 
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
 
             if self.config.class_embeddings_concat:
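As a hedged end-to-end sketch (again, not part of the commit), the path this change fixes can be exercised with a half-precision UNet2DConditionModel configured with class_embed_type="timestep". The tiny configuration below is made up purely to keep the example small, and a CUDA device is assumed for the fp16 forward pass.

import torch
from diffusers import UNet2DConditionModel

# Small, made-up configuration; the relevant pieces are class_embed_type="timestep"
# and running the model in float16.
unet = UNet2DConditionModel(
    sample_size=16,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    class_embed_type="timestep",  # class labels go through time_proj + class_embedding
).to("cuda", torch.float16)

sample = torch.randn(1, 4, 16, 16, device="cuda", dtype=torch.float16)
encoder_hidden_states = torch.randn(1, 8, 32, device="cuda", dtype=torch.float16)
class_labels = torch.tensor([5], device="cuda")

# Before this change, time_proj(class_labels) stayed float32 while class_embedding
# ran in fp16, producing a dtype mismatch; with the added cast the call goes through.
out = unet(sample, timestep=10, encoder_hidden_states=encoder_hidden_states,
           class_labels=class_labels).sample
print(out.dtype, out.shape)  # torch.float16, torch.Size([1, 4, 16, 16])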