Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Fix issue with prompt embeds and latents in SD Cascade Decoder with multiple image embeddings for a single prompt. (#7381)
* fix
* update
* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
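For context, a minimal sketch of the user-facing scenario this fix targets: the prior is run with num_images_per_prompt > 1, so a single prompt arrives at the decoder with multiple image embeddings. The checkpoint names, dtypes, and prompt below follow the usual Stable Cascade example and are illustrative only, not part of this diff.

import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
).to("cuda")
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.float16
).to("cuda")

prompt = "an astronaut riding a horse"

# num_images_per_prompt=2 makes the prior emit two image embeddings for one prompt.
prior_output = prior(prompt=prompt, num_images_per_prompt=2)

# Before this fix, the decoder inferred the batch size from image_embeddings.shape[0] (here 2),
# which mis-sized the prompt embeds and latents for the single prompt.
images = decoder(
    image_embeddings=prior_output.image_embeddings.to(torch.float16),
    prompt=prompt,
    guidance_scale=0.0,
    output_type="pil",
).images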
@@ -100,8 +100,10 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
         )
         self.register_to_config(latent_dim_scale=latent_dim_scale)
 
-    def prepare_latents(self, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler):
-        batch_size, channels, height, width = image_embeddings.shape
+    def prepare_latents(
+        self, batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler
+    ):
+        _, channels, height, width = image_embeddings.shape
         latents_shape = (
             batch_size * num_images_per_prompt,
             4,
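The signature change matters because image_embeddings.shape[0] is no longer a reliable proxy for the prompt batch size: when the prior ran with num_images_per_prompt > 1, that first dimension already includes the extra images. A standalone sketch of the shape bookkeeping follows; the latent_dim_scale of 10.67, the 4 latent channels, and the 24x24 embedding grid are assumed defaults used only for illustration.

import torch

# Standalone sketch, not the pipeline method: mirrors how the updated
# prepare_latents sizes its noise tensor from an explicit batch_size.
def latent_shape(batch_size, num_images_per_prompt, image_embeddings, latent_dim_scale=10.67):
    _, _, height, width = image_embeddings.shape  # first dim deliberately ignored
    return (
        batch_size * num_images_per_prompt,
        4,
        int(height * latent_dim_scale),
        int(width * latent_dim_scale),
    )

# Prior ran with num_images_per_prompt=2 for a single prompt -> 2 embeddings.
image_embeddings = torch.randn(2, 16, 24, 24)

# With the effective num_images_per_prompt of 4 computed in __call__ (see the next
# hunk), passing batch_size=1 explicitly yields 4 latents; inferring it from the
# embeddings' first dimension would have given 8 and mismatched the prompt embeds.
print(latent_shape(batch_size=1, num_images_per_prompt=4, image_embeddings=image_embeddings))
# (4, 4, 256, 256) under these assumptions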
@@ -383,7 +385,19 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
         )
         if isinstance(image_embeddings, list):
             image_embeddings = torch.cat(image_embeddings, dim=0)
-        batch_size = image_embeddings.shape[0]
+
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        # Compute the effective number of images per prompt
+        # We must account for the fact that the image embeddings from the prior can be generated with num_images_per_prompt > 1
+        # This results in a case where a single prompt is associated with multiple image embeddings
+        # Divide the number of image embeddings by the batch size to determine if this is the case.
+        num_images_per_prompt = num_images_per_prompt * (image_embeddings.shape[0] // batch_size)
 
         # 2. Encode caption
         if prompt_embeds is None and negative_prompt_embeds is None:
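The comment block above is the heart of the fix; a quick worked example of the arithmetic, with hypothetical numbers:

# Hypothetical numbers, mirroring the bookkeeping in the hunk above.
prompt = ["a cat"]
batch_size = len(prompt)              # 1
prior_num_images_per_prompt = 2       # the prior produced 2 embeddings for that prompt
num_image_embeddings = batch_size * prior_num_images_per_prompt  # image_embeddings.shape[0] == 2

decoder_num_images_per_prompt = 2
effective = decoder_num_images_per_prompt * (num_image_embeddings // batch_size)
assert effective == 4                 # prompt embeds and latents are both expanded to 4 images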
@@ -417,7 +431,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
 
         # 5. Prepare latents
         latents = self.prepare_latents(
-            image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
+            batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
         )
 
         # 6. Run denoising loop
@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (
     slow,
     torch_device,
 )
+from diffusers.utils.torch_utils import randn_tensor
 
 from ..test_pipelines_common import PipelineTesterMixin
 
@@ -246,6 +247,66 @@ class StableCascadeDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCa
 
         assert np.abs(decoder_output_prompt.images - decoder_output_prompt_embeds.images).max() < 1e-5
 
+    def test_stable_cascade_decoder_single_prompt_multiple_image_embeddings(self):
+        device = "cpu"
+        components = self.get_dummy_components()
+
+        pipe = StableCascadeDecoderPipeline(**components)
+        pipe.set_progress_bar_config(disable=None)
+
+        prior_num_images_per_prompt = 2
+        decoder_num_images_per_prompt = 2
+        prompt = ["a cat"]
+        batch_size = len(prompt)
+
+        generator = torch.Generator(device)
+        image_embeddings = randn_tensor(
+            (batch_size * prior_num_images_per_prompt, 4, 4, 4), generator=generator.manual_seed(0)
+        )
+        decoder_output = pipe(
+            image_embeddings=image_embeddings,
+            prompt=prompt,
+            num_inference_steps=1,
+            output_type="np",
+            guidance_scale=0.0,
+            generator=generator.manual_seed(0),
+            num_images_per_prompt=decoder_num_images_per_prompt,
+        )
+
+        assert decoder_output.images.shape[0] == (
+            batch_size * prior_num_images_per_prompt * decoder_num_images_per_prompt
+        )
+
+    def test_stable_cascade_decoder_single_prompt_multiple_image_embeddings_with_guidance(self):
+        device = "cpu"
+        components = self.get_dummy_components()
+
+        pipe = StableCascadeDecoderPipeline(**components)
+        pipe.set_progress_bar_config(disable=None)
+
+        prior_num_images_per_prompt = 2
+        decoder_num_images_per_prompt = 2
+        prompt = ["a cat"]
+        batch_size = len(prompt)
+
+        generator = torch.Generator(device)
+        image_embeddings = randn_tensor(
+            (batch_size * prior_num_images_per_prompt, 4, 4, 4), generator=generator.manual_seed(0)
+        )
+        decoder_output = pipe(
+            image_embeddings=image_embeddings,
+            prompt=prompt,
+            num_inference_steps=1,
+            output_type="np",
+            guidance_scale=2.0,
+            generator=generator.manual_seed(0),
+            num_images_per_prompt=decoder_num_images_per_prompt,
+        )
+
+        assert decoder_output.images.shape[0] == (
+            batch_size * prior_num_images_per_prompt * decoder_num_images_per_prompt
+        )
+
 
 @slow
 @require_torch_gpu