mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix indexing issue in sd reference pipeline (#4531)
This commit is contained in:
@@ -153,8 +153,6 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline):
|
||||
)
|
||||
ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)
|
||||
|
||||
ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents
|
||||
|
||||
# aligning device to prevent device errors when concating it with the latent model input
|
||||
ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
|
||||
return ref_image_latents
|
||||
@@ -733,6 +731,7 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline):
|
||||
1,
|
||||
),
|
||||
)
|
||||
ref_xt = torch.cat([ref_xt] * 2) if do_classifier_free_guidance else ref_xt
|
||||
ref_xt = self.scheduler.scale_model_input(ref_xt, t)
|
||||
|
||||
MODE = "write"
|
||||
|
||||
Reference in New Issue
Block a user