1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix dtype error for StableDiffusionXL (#9217)

Fix dtype error

Co-authored-by: 蒋硕 <jiangshuo9@h-partners.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Leo Jiang
2024-08-20 09:28:50 +08:00
committed by GitHub
parent 803e817e3e
commit eda36c4c28

View File

@@ -1084,7 +1084,7 @@ def main(args):
# Add noise to the model input according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps).to(dtype=weight_dtype)
# time ids
def compute_time_ids(original_size, crops_coords_top_left):
@@ -1101,7 +1101,7 @@ def main(args):
# Predict the noise residual
unet_added_conditions = {"time_ids": add_time_ids}
prompt_embeds = batch["prompt_embeds"].to(accelerator.device)
prompt_embeds = batch["prompt_embeds"].to(accelerator.device, dtype=weight_dtype)
pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
model_pred = unet(