1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

move to the end of loop

This commit is contained in:
yiyixuxu
2023-10-19 05:39:54 +00:00
parent fe63843c33
commit a5512d49d5

View File

@@ -729,19 +729,6 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_inputs:
callback_kwargs[k] = locals()[k]
callback_kwargs = callback_on_step_end(i, t, callback_kwargs)
latents = callback_kwargs.pop("latents", latents)
guidance_scale = callback_kwargs.pop("guidance_scale", guidance_scale)
prompt_embeds = callback_kwargs.pop("prompt_embeds", prompt_embeds)
cross_attention_kwargs = callback_kwargs.pop("cross_attention_kwargs", cross_attention_kwargs)
guidance_rescale = callback_kwargs.pop("guidance_rescale", guidance_rescale)
do_classifier_free_guidance = guidance_scale > 1.0
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -767,6 +754,19 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_inputs:
callback_kwargs[k] = locals()[k]
callback_kwargs = callback_on_step_end(i, t, callback_kwargs)
latents = callback_kwargs.pop("latents", latents)
guidance_scale = callback_kwargs.pop("guidance_scale", guidance_scale)
prompt_embeds = callback_kwargs.pop("prompt_embeds", prompt_embeds)
cross_attention_kwargs = callback_kwargs.pop("cross_attention_kwargs", cross_attention_kwargs)
guidance_rescale = callback_kwargs.pop("guidance_rescale", guidance_rescale)
do_classifier_free_guidance = guidance_scale > 1.0
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()