1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix issue with prompt embeds and latents in SD Cascade Decoder with multiple image embeddings for a single prompt. (#7381)

* fix

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Dhruv Nair
2024-03-19 23:10:14 +05:30
committed by GitHub
parent b09a2aa308
commit 80ff4ba63e
2 changed files with 79 additions and 4 deletions

View File

@@ -100,8 +100,10 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
)
self.register_to_config(latent_dim_scale=latent_dim_scale)
def prepare_latents(self, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler):
batch_size, channels, height, width = image_embeddings.shape
def prepare_latents(
self, batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler
):
_, channels, height, width = image_embeddings.shape
latents_shape = (
batch_size * num_images_per_prompt,
4,
@@ -383,7 +385,19 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
)
if isinstance(image_embeddings, list):
image_embeddings = torch.cat(image_embeddings, dim=0)
batch_size = image_embeddings.shape[0]
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# Compute the effective number of images per prompt
# We must account for the fact that the image embeddings from the prior can be generated with num_images_per_prompt > 1
# This results in a case where a single prompt is associated with multiple image embeddings
# Divide the number of image embeddings by the batch size to determine if this is the case.
num_images_per_prompt = num_images_per_prompt * (image_embeddings.shape[0] // batch_size)
# 2. Encode caption
if prompt_embeds is None and negative_prompt_embeds is None:
@@ -417,7 +431,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
# 5. Prepare latents
latents = self.prepare_latents(
image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
)
# 6. Run denoising loop

View File

@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
from diffusers.utils.torch_utils import randn_tensor
from ..test_pipelines_common import PipelineTesterMixin
@@ -246,6 +247,66 @@ class StableCascadeDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCa
assert np.abs(decoder_output_prompt.images - decoder_output_prompt_embeds.images).max() < 1e-5
def test_stable_cascade_decoder_single_prompt_multiple_image_embeddings(self):
device = "cpu"
components = self.get_dummy_components()
pipe = StableCascadeDecoderPipeline(**components)
pipe.set_progress_bar_config(disable=None)
prior_num_images_per_prompt = 2
decoder_num_images_per_prompt = 2
prompt = ["a cat"]
batch_size = len(prompt)
generator = torch.Generator(device)
image_embeddings = randn_tensor(
(batch_size * prior_num_images_per_prompt, 4, 4, 4), generator=generator.manual_seed(0)
)
decoder_output = pipe(
image_embeddings=image_embeddings,
prompt=prompt,
num_inference_steps=1,
output_type="np",
guidance_scale=0.0,
generator=generator.manual_seed(0),
num_images_per_prompt=decoder_num_images_per_prompt,
)
assert decoder_output.images.shape[0] == (
batch_size * prior_num_images_per_prompt * decoder_num_images_per_prompt
)
def test_stable_cascade_decoder_single_prompt_multiple_image_embeddings_with_guidance(self):
device = "cpu"
components = self.get_dummy_components()
pipe = StableCascadeDecoderPipeline(**components)
pipe.set_progress_bar_config(disable=None)
prior_num_images_per_prompt = 2
decoder_num_images_per_prompt = 2
prompt = ["a cat"]
batch_size = len(prompt)
generator = torch.Generator(device)
image_embeddings = randn_tensor(
(batch_size * prior_num_images_per_prompt, 4, 4, 4), generator=generator.manual_seed(0)
)
decoder_output = pipe(
image_embeddings=image_embeddings,
prompt=prompt,
num_inference_steps=1,
output_type="np",
guidance_scale=2.0,
generator=generator.manual_seed(0),
num_images_per_prompt=decoder_num_images_per_prompt,
)
assert decoder_output.images.shape[0] == (
batch_size * prior_num_images_per_prompt * decoder_num_images_per_prompt
)
@slow
@require_torch_gpu