diff --git a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
index 2228649eb5..97908cc16d 100644
--- a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
+++ b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
@@ -611,7 +611,6 @@ class UniDiffuserPipeline(DiffusionPipeline):
     def encode_image_clip_latents(
         self,
         image,
-        resolution,
         batch_size,
         num_prompts_per_image,
         dtype,
@@ -626,8 +625,6 @@ class UniDiffuserPipeline(DiffusionPipeline):
 
         preprocessed_image = self.image_processor.preprocess(
             image,
-            do_center_crop=True,
-            crop_size=resolution,
             return_tensors="pt",
         )
         preprocessed_image = preprocessed_image.to(device=device, dtype=dtype)
@@ -794,13 +791,13 @@ class UniDiffuserPipeline(DiffusionPipeline):
         latent_height = height // self.vae_scale_factor
         latent_width = width // self.vae_scale_factor
         img_vae_dim = self.num_channels_latents * latent_height * latent_width
-        text_dim = self.text_encoder_seq_len * self.text_encoder_hidden_size
+        text_dim = self.text_encoder_seq_len * self.text_intermediate_dim
 
         img_vae, img_clip, text = x.split([img_vae_dim, self.image_encoder_projection_dim, text_dim], dim=1)
 
         img_vae = torch.reshape(img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width))
         img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_projection_dim))
-        text = torch.reshape(text, (batch_size, self.text_encoder_seq_len, self.text_encoder_hidden_size))
+        text = torch.reshape(text, (batch_size, self.text_encoder_seq_len, self.text_intermediate_dim))
         return img_vae, img_clip, text
 
     def _combine_joint(self, img_vae, img_clip, text):
@@ -1268,7 +1265,6 @@ class UniDiffuserPipeline(DiffusionPipeline):
             # Encode image using CLIP
             image_clip_latents = self.encode_image_clip_latents(
                 image=image,
-                resolution=height,  # assume image is square for now...
                 batch_size=batch_size,
                 num_prompts_per_image=multiplier,
                 dtype=prompt_embeds.dtype,
@@ -1319,6 +1315,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
             latents = prompt_embeds
 
         print(f"Initial latents: {latents}")
+        print(f"Initial latents shape: {latents.shape}")
 
         # 7. Check that shapes of latents and image match the UNet channels.
         # TODO