diff --git a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
index a0daa73828..a3206cd331 100644
--- a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
+++ b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
@@ -327,7 +327,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
 
     def set_joint_mode(self):
         self.mode = "joint"
-
+
     def reset_mode(self):
         self.mode = None
 
@@ -349,10 +349,10 @@ class UniDiffuserPipeline(DiffusionPipeline):
             num_images_per_prompt = 1
         if num_prompts_per_image is None:
            num_prompts_per_image = 1
-
+
         assert num_images_per_prompt > 0, "num_images_per_prompt must be a positive integer"
         assert num_prompts_per_image > 0, "num_prompts_per_image must be a positive integer"
-
+
         if mode in ["text2img"]:
             if prompt is not None and isinstance(prompt, str):
                 batch_size = 1
@@ -666,7 +666,9 @@ class UniDiffuserPipeline(DiffusionPipeline):
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
-    def prepare_text_latents(self, batch_size, num_images_per_prompt, seq_len, hidden_size, dtype, device, generator, latents=None):
+    def prepare_text_latents(
+        self, batch_size, num_images_per_prompt, seq_len, hidden_size, dtype, device, generator, latents=None
+    ):
         # Prepare latents for the CLIP embedded prompt.
         shape = (batch_size * num_images_per_prompt, seq_len, hidden_size)
         if isinstance(generator, list) and len(generator) != batch_size:
@@ -678,7 +680,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            latents = einops.repeat(latents, 'B L D -> (repeat B) L D', repeat=num_images_per_prompt)
+            latents = einops.repeat(latents, "B L D -> (repeat B) L D", repeat=num_images_per_prompt)
             latents = latents.to(device)
 
         # scale the initial noise by the standard deviation required by the scheduler
@@ -688,9 +690,23 @@ class UniDiffuserPipeline(DiffusionPipeline):
     # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     # Rename: prepare_latents -> prepare_image_vae_latents
     def prepare_image_vae_latents(
-        self, batch_size, num_prompts_per_image, num_channels_latents, height, width, dtype, device, generator, latents=None
+        self,
+        batch_size,
+        num_prompts_per_image,
+        num_channels_latents,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        latents=None,
     ):
-        shape = (batch_size * num_prompts_per_image, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        shape = (
+            batch_size * num_prompts_per_image,
+            num_channels_latents,
+            height // self.vae_scale_factor,
+            width // self.vae_scale_factor,
+        )
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -700,14 +716,16 @@ class UniDiffuserPipeline(DiffusionPipeline):
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            latents = einops.repeat(latents, 'B C H W -> (repeat B) C H W', repeat=num_prompts_per_image)
+            latents = einops.repeat(latents, "B C H W -> (repeat B) C H W", repeat=num_prompts_per_image)
             latents = latents.to(device)
 
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
         return latents
 
-    def prepare_image_clip_latents(self, batch_size, num_prompts_per_image, clip_img_dim, dtype, device, generator, latents=None):
+    def prepare_image_clip_latents(
+        self, batch_size, num_prompts_per_image, clip_img_dim, dtype, device, generator, latents=None
+    ):
         # Prepare latents for the CLIP embedded image.
         shape = (batch_size * num_prompts_per_image, 1, clip_img_dim)
         if isinstance(generator, list) and len(generator) != batch_size:
@@ -719,7 +737,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            latents = einops.repeat(latents, 'B L D -> (repeat B) L D', repeat=num_prompts_per_image)
+            latents = einops.repeat(latents, "B L D -> (repeat B) L D", repeat=num_prompts_per_image)
             latents = latents.to(device)
 
         # scale the initial noise by the standard deviation required by the scheduler
@@ -887,16 +905,16 @@ class UniDiffuserPipeline(DiffusionPipeline):
 
         img_out = self._combine(img_vae_out, img_clip_out)
         return img_out
-
+
     def check_latents_shape(self, latents_name, latents, expected_shape):
         latents_shape = latents.shape
         expected_num_dims = len(expected_shape) + 1  # expected dimensions plus the batch dimension
         expected_shape_str = ", ".join(str(dim) for dim in expected_shape)
         if len(latents_shape) != expected_num_dims:
             raise ValueError(
-                f"`{latents_name}` should have shape (batch_size, {expected_shape_str}), but the current shape"
-                f" {latents_shape} has {len(latents_shape)} dimensions."
-                )
+                f"`{latents_name}` should have shape (batch_size, {expected_shape_str}), but the current shape"
+                f" {latents_shape} has {len(latents_shape)} dimensions."
+            )
         for i in range(1, expected_num_dims):
             if latents_shape[i] != expected_shape[i - 1]:
                 raise ValueError(
@@ -960,13 +978,11 @@ class UniDiffuserPipeline(DiffusionPipeline):
                     f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                     f" {negative_prompt_embeds.shape}."
                 )
-
+
         if mode == "img2text":
             if image is None:
-                raise ValueError(
-                    "`img2text` mode requires an image to be provided."
-                )
-
+                raise ValueError("`img2text` mode requires an image to be provided.")
+
         # Check provided latents
         latent_height = height // self.vae_scale_factor
         latent_width = width // self.vae_scale_factor
@@ -990,27 +1006,27 @@ class UniDiffuserPipeline(DiffusionPipeline):
             latents_dim = img_vae_dim + self.image_encoder_hidden_size + text_dim
             latents_expected_shape = (latents_dim,)
             self.check_latents_shape("latents", latents, latents_expected_shape)
-
+
         # Check individual latent shapes, if present
         if prompt_latents_available:
             prompt_latents_expected_shape = (self.text_encoder_seq_len, self.text_encoder_hidden_size)
             self.check_latents_shape("prompt_latents", prompt_latents, prompt_latents_expected_shape)
-
+
         if vae_latents_available:
             vae_latents_expected_shape = (self.num_channels_latents, latent_height, latent_width)
             self.check_latents_shape("vae_latents", vae_latents, vae_latents_expected_shape)
-
+
         if clip_latents_available:
             clip_latents_expected_shape = (1, self.image_encoder_hidden_size)
             self.check_latents_shape("clip_latents", clip_latents, clip_latents_expected_shape)
-
+
         if mode in ["text2img", "img"] and vae_latents_available and clip_latents_available:
             if vae_latents.shape[0] != clip_latents.shape[0]:
                 raise ValueError(
                     f"Both `vae_latents` and `clip_latents` are supplied, but their batch dimensions are not equal:"
                     f" {vae_latents.shape[0]} != {clip_latents.shape[0]}."
                 )
-
+
         if mode == "joint" and prompt_latents_available and vae_latents_available and clip_latents_available:
             if prompt_latents.shape[0] != vae_latents.shape[0] or prompt_latents.shape[0] != clip_latents.shape[0]:
                 raise ValueError(
@@ -1076,12 +1092,12 @@ class UniDiffuserPipeline(DiffusionPipeline):
                 `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
                 Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
             num_images_per_prompt (`int`, *optional*, defaults to 1):
-                The number of images to generate per prompt. Used in `text2img` (text-conditioned image generation)
-                and `img` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are
+                The number of images to generate per prompt. Used in `text2img` (text-conditioned image generation) and
+                `img` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are
                 supplied, `min(num_images_per_prompt, num_prompts_per_image)` samples will be generated.
             num_prompts_per_image (`int`, *optional*, defaults to 1):
-                The number of prompts to generate per image. Used in `img2text` (image-conditioned text generation)
-                and `text` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are
+                The number of prompts to generate per image. Used in `img2text` (image-conditioned text generation) and
+                `text` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are
                 supplied, `min(num_images_per_prompt, num_prompts_per_image)` samples will be generated.
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
diff --git a/tests/pipelines/unidiffuser/test_unidiffuser.py b/tests/pipelines/unidiffuser/test_unidiffuser.py
index 405da9088b..c622c682aa 100644
--- a/tests/pipelines/unidiffuser/test_unidiffuser.py
+++ b/tests/pipelines/unidiffuser/test_unidiffuser.py
@@ -321,7 +321,7 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
 
         expected_text_prefix = " no no no "
         assert text[0][:10] == expected_text_prefix
-
+
     def test_unidiffuser_text2img_multiple_images(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
@@ -340,7 +340,7 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         inputs["num_prompts_per_image"] = 3
         image = unidiffuser_pipe(**inputs).images
         assert image.shape == (2, 32, 32, 3)
-
+
     def test_unidiffuser_img2text_multiple_prompts(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
@@ -360,7 +360,7 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
 
         text = unidiffuser_pipe(**inputs).text
         assert len(text) == 3
-
+
     def test_unidiffuser_text2img_multiple_images_with_latents(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
@@ -379,7 +379,7 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         inputs["num_prompts_per_image"] = 3
         image = unidiffuser_pipe(**inputs).images
         assert image.shape == (2, 32, 32, 3)
-
+
     def test_unidiffuser_img2text_multiple_prompts_with_latents(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
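For reviewers, a minimal usage sketch of the batching arguments this patch reformats (`num_images_per_prompt`, `num_prompts_per_image`), modeled on the tests above. It is not part of the patch: the checkpoint id `thu-ml/unidiffuser-v1` and the prompt are placeholder assumptions, and the joint-mode counts follow the `min(...)` rule stated in the docstring.

```python
import torch

from diffusers import UniDiffuserPipeline

# Placeholder checkpoint id (assumption); any UniDiffuser checkpoint should work.
pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# text2img: `num_images_per_prompt` images are sampled for the single prompt,
# mirroring test_unidiffuser_text2img_multiple_images.
sample = pipe(prompt="an astronaut riding a horse", num_images_per_prompt=2)
assert len(sample.images) == 2

# joint mode: with both arguments supplied, the docstring promises
# min(num_images_per_prompt, num_prompts_per_image) = min(2, 3) = 2 samples.
pipe.set_joint_mode()
sample = pipe(num_images_per_prompt=2, num_prompts_per_image=3)
assert len(sample.images) == 2 and len(sample.text) == 2
```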