diff --git a/examples/community/stable_diffusion_reference.py b/examples/community/stable_diffusion_reference.py index a2fd849ac0..68e30f15bc 100644 --- a/examples/community/stable_diffusion_reference.py +++ b/examples/community/stable_diffusion_reference.py @@ -153,8 +153,6 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline): ) ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1) - ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents - # aligning device to prevent device errors when concating it with the latent model input ref_image_latents = ref_image_latents.to(device=device, dtype=dtype) return ref_image_latents @@ -733,6 +731,7 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline): 1, ), ) + ref_xt = torch.cat([ref_xt] * 2) if do_classifier_free_guidance else ref_xt ref_xt = self.scheduler.scale_model_input(ref_xt, t) MODE = "write"