mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix dtype error for StableDiffusionXL (#9217)
Fix dtype error Co-authored-by: 蒋硕 <jiangshuo9@h-partners.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -1084,7 +1084,7 @@ def main(args):
|
||||
|
||||
# Add noise to the model input according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
|
||||
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps).to(dtype=weight_dtype)
|
||||
|
||||
# time ids
|
||||
def compute_time_ids(original_size, crops_coords_top_left):
|
||||
@@ -1101,7 +1101,7 @@ def main(args):
|
||||
|
||||
# Predict the noise residual
|
||||
unet_added_conditions = {"time_ids": add_time_ids}
|
||||
prompt_embeds = batch["prompt_embeds"].to(accelerator.device)
|
||||
prompt_embeds = batch["prompt_embeds"].to(accelerator.device, dtype=weight_dtype)
|
||||
pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
|
||||
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
|
||||
model_pred = unet(
|
||||
|
||||
Reference in New Issue
Block a user