1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Fix shape errors for the 'joint' and 'img2text' modes.

This commit is contained in:
Daniel Gu
2023-05-11 17:43:52 -07:00
parent 10e3774b8e
commit 4d656b50a0

View File

@@ -611,7 +611,6 @@ class UniDiffuserPipeline(DiffusionPipeline):
def encode_image_clip_latents(
self,
image,
resolution,
batch_size,
num_prompts_per_image,
dtype,
@@ -626,8 +625,6 @@ class UniDiffuserPipeline(DiffusionPipeline):
preprocessed_image = self.image_processor.preprocess(
image,
do_center_crop=True,
crop_size=resolution,
return_tensors="pt",
)
preprocessed_image = preprocessed_image.to(device=device, dtype=dtype)
@@ -794,13 +791,13 @@ class UniDiffuserPipeline(DiffusionPipeline):
latent_height = height // self.vae_scale_factor
latent_width = width // self.vae_scale_factor
img_vae_dim = self.num_channels_latents * latent_height * latent_width
- text_dim = self.text_encoder_seq_len * self.text_encoder_hidden_size
+ text_dim = self.text_encoder_seq_len * self.text_intermediate_dim
img_vae, img_clip, text = x.split([img_vae_dim, self.image_encoder_projection_dim, text_dim], dim=1)
img_vae = torch.reshape(img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width))
img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_projection_dim))
- text = torch.reshape(text, (batch_size, self.text_encoder_seq_len, self.text_encoder_hidden_size))
+ text = torch.reshape(text, (batch_size, self.text_encoder_seq_len, self.text_intermediate_dim))
return img_vae, img_clip, text
def _combine_joint(self, img_vae, img_clip, text):
@@ -1268,7 +1265,6 @@ class UniDiffuserPipeline(DiffusionPipeline):
# Encode image using CLIP
image_clip_latents = self.encode_image_clip_latents(
image=image,
resolution=height, # assume image is square for now...
batch_size=batch_size,
num_prompts_per_image=multiplier,
dtype=prompt_embeds.dtype,
@@ -1319,6 +1315,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
latents = prompt_embeds
print(f"Initial latents: {latents}")
print(f"Initial latents shape: {latents.shape}")
# 7. Check that shapes of latents and image match the UNet channels.
# TODO