diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index dc39480b65..8d3f7cbbe3 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1665,8 +1665,7 @@ class CombinedTimestepTextProjChromaEmbeddings(nn.Module): torch.cat([timesteps_proj, guidance_proj], dim=1).repeat(1, mod_index_length, 1) ) input_vec = torch.cat([timestep_guidance, mod_proj], dim=-1) - input_vec.to(dtype=timestep.dtype) - return input_vec + return input_vec.to(dtype=timestep.dtype) class CogView3CombinedTimestepSizeEmbeddings(nn.Module): def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):