From 5f9b153183d0cde3b850f14024d2e37ae8c19576 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 24 Oct 2023 00:50:45 +0000 Subject: [PATCH] update --- .../pipeline_stable_diffusion.py | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 2d78fc3e03..8fe8a10714 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -698,11 +698,6 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo lora_scale=text_encoder_lora_scale, clip_skip=clip_skip, ) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -733,11 +728,20 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds_cond = ( + torch.cat([negative_prompt_embeds, prompt_embeds]) + if do_classifier_free_guidance + else prompt_embeds + ) + # predict the noise residual noise_pred = self.unet( latent_model_input, t, - encoder_hidden_states=prompt_embeds, + encoder_hidden_states=prompt_embeds_cond, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] @@ -758,13 +762,14 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo callback_kwargs = {} for k in callback_on_step_end_inputs: callback_kwargs[k] = locals()[k] - callback_kwargs = callback_on_step_end(i, t, callback_kwargs) + callback_results = callback_on_step_end(i, t, callback_kwargs) - latents = callback_kwargs.pop("latents", latents) - guidance_scale = callback_kwargs.pop("guidance_scale", guidance_scale) - prompt_embeds = callback_kwargs.pop("prompt_embeds", prompt_embeds) - cross_attention_kwargs = callback_kwargs.pop("cross_attention_kwargs", cross_attention_kwargs) - guidance_rescale = callback_kwargs.pop("guidance_rescale", guidance_rescale) + latents = callback_results.pop("latents", latents) + guidance_scale = callback_results.pop("guidance_scale", guidance_scale) + prompt_embeds = callback_results.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_results.pop("negative_prompt_embeds", negative_prompt_embeds) + cross_attention_kwargs = callback_results.pop("cross_attention_kwargs", cross_attention_kwargs) + guidance_rescale = callback_results.pop("guidance_rescale", guidance_rescale) do_classifier_free_guidance = guidance_scale > 1.0 # call the callback, if provided