From c91272d63155faaf35e62f3e2e8a3b17fa28610a Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 8 Aug 2023 15:14:19 +0200 Subject: [PATCH] fix indexing issue in sd reference pipeline (#4531) --- examples/community/stable_diffusion_reference.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/community/stable_diffusion_reference.py b/examples/community/stable_diffusion_reference.py index a2fd849ac0..68e30f15bc 100644 --- a/examples/community/stable_diffusion_reference.py +++ b/examples/community/stable_diffusion_reference.py @@ -153,8 +153,6 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline): ) ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1) - ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents - # aligning device to prevent device errors when concating it with the latent model input ref_image_latents = ref_image_latents.to(device=device, dtype=dtype) return ref_image_latents @@ -733,6 +731,7 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline): 1, ), ) + ref_xt = torch.cat([ref_xt] * 2) if do_classifier_free_guidance else ref_xt ref_xt = self.scheduler.scale_model_input(ref_xt, t) MODE = "write"