commit 5f9b153183
parent a5512d49d5
Author: yiyixuxu
Date:   2023-10-24 00:50:45 +00:00


@@ -698,11 +698,6 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
             lora_scale=text_encoder_lora_scale,
             clip_skip=clip_skip,
         )
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        if do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -733,11 +728,20 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
+                # For classifier free guidance, we need to do two forward passes.
+                # Here we concatenate the unconditional and text embeddings into a single batch
+                # to avoid doing two forward passes
+                prompt_embeds_cond = (
+                    torch.cat([negative_prompt_embeds, prompt_embeds])
+                    if do_classifier_free_guidance
+                    else prompt_embeds
+                )
+
                 # predict the noise residual
                 noise_pred = self.unet(
                     latent_model_input,
                     t,
-                    encoder_hidden_states=prompt_embeds,
+                    encoder_hidden_states=prompt_embeds_cond,
                     cross_attention_kwargs=cross_attention_kwargs,
                     return_dict=False,
                 )[0]
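Not part of the diff, but useful context for the batching above: a few lines below the UNet call, Stable Diffusion pipelines split the batched prediction back apart and apply the guidance scale. A minimal sketch of that standard combination, assuming the concatenation order used here (negative embeddings first, then positive):

# Sketch of the classifier-free guidance combination that consumes the
# batched noise_pred above (standard diffusers pattern, shown for context).
if do_classifier_free_guidance:
    # chunk(2) undoes torch.cat([negative_prompt_embeds, prompt_embeds]):
    # first half is the unconditional prediction, second half the text one.
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)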
@@ -758,13 +762,14 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
                     callback_kwargs = {}
                     for k in callback_on_step_end_inputs:
                         callback_kwargs[k] = locals()[k]
-                    callback_kwargs = callback_on_step_end(i, t, callback_kwargs)
+                    callback_results = callback_on_step_end(i, t, callback_kwargs)
 
-                    latents = callback_kwargs.pop("latents", latents)
-                    guidance_scale = callback_kwargs.pop("guidance_scale", guidance_scale)
-                    prompt_embeds = callback_kwargs.pop("prompt_embeds", prompt_embeds)
-                    cross_attention_kwargs = callback_kwargs.pop("cross_attention_kwargs", cross_attention_kwargs)
-                    guidance_rescale = callback_kwargs.pop("guidance_rescale", guidance_rescale)
+                    latents = callback_results.pop("latents", latents)
+                    guidance_scale = callback_results.pop("guidance_scale", guidance_scale)
+                    prompt_embeds = callback_results.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_results.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    cross_attention_kwargs = callback_results.pop("cross_attention_kwargs", cross_attention_kwargs)
+                    guidance_rescale = callback_results.pop("guidance_rescale", guidance_rescale)
                     do_classifier_free_guidance = guidance_scale > 1.0
 
                 # call the callback, if provided
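The switch from mutating callback_kwargs in place to reading a returned callback_results dict, plus re-deriving do_classifier_free_guidance every step, means a callback_on_step_end hook can change guidance behavior mid-denoising. A hypothetical sketch; the hook name and the pipe/prompt variables are illustrative, while the (i, t, callback_kwargs) signature and the callback_on_step_end_inputs parameter come from the diff above:

# Hypothetical callback: drop classifier-free guidance after 10 steps.
# Setting guidance_scale <= 1.0 makes the loop recompute
# do_classifier_free_guidance = guidance_scale > 1.0 as False (see diff).
def turn_off_cfg(i, t, callback_kwargs):
    if i >= 10:
        callback_kwargs["guidance_scale"] = 0.0
    return callback_kwargs  # consumed via callback_results.pop(...) in the loop

# Illustrative usage, assuming `pipe` is a loaded StableDiffusionPipeline:
image = pipe(
    "an astronaut riding a horse",
    callback_on_step_end=turn_off_cfg,
    callback_on_step_end_inputs=["guidance_scale"],
).images[0]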