mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
update
This commit is contained in:
@@ -698,11 +698,6 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
clip_skip=clip_skip,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
@@ -733,11 +728,20 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds_cond = (
|
||||
torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
if do_classifier_free_guidance
|
||||
else prompt_embeds
|
||||
)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_hidden_states=prompt_embeds_cond,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
@@ -758,13 +762,14 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_kwargs = callback_on_step_end(i, t, callback_kwargs)
|
||||
callback_results = callback_on_step_end(i, t, callback_kwargs)
|
||||
|
||||
latents = callback_kwargs.pop("latents", latents)
|
||||
guidance_scale = callback_kwargs.pop("guidance_scale", guidance_scale)
|
||||
prompt_embeds = callback_kwargs.pop("prompt_embeds", prompt_embeds)
|
||||
cross_attention_kwargs = callback_kwargs.pop("cross_attention_kwargs", cross_attention_kwargs)
|
||||
guidance_rescale = callback_kwargs.pop("guidance_rescale", guidance_rescale)
|
||||
latents = callback_results.pop("latents", latents)
|
||||
guidance_scale = callback_results.pop("guidance_scale", guidance_scale)
|
||||
prompt_embeds = callback_results.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_results.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
cross_attention_kwargs = callback_results.pop("cross_attention_kwargs", cross_attention_kwargs)
|
||||
guidance_rescale = callback_results.pop("guidance_rescale", guidance_rescale)
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# call the callback, if provided
|
||||
|
||||
Reference in New Issue
Block a user