Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
Fix shape errors for the 'joint' and 'img2text' modes.
@@ -611,7 +611,6 @@ class UniDiffuserPipeline(DiffusionPipeline):
     def encode_image_clip_latents(
         self,
         image,
-        resolution,
         batch_size,
         num_prompts_per_image,
         dtype,
@@ -626,8 +625,6 @@ class UniDiffuserPipeline(DiffusionPipeline):
 
         preprocessed_image = self.image_processor.preprocess(
             image,
-            do_center_crop=True,
-            crop_size=resolution,
             return_tensors="pt",
         )
         preprocessed_image = preprocessed_image.to(device=device, dtype=dtype)
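
The two dropped keyword arguments suggest the pipeline now leans on the CLIP image processor's own configuration, which resizes and center-crops to its configured size by default, making an explicit per-call resolution redundant. A minimal sketch of that behavior, assuming a default transformers CLIPImageProcessor (the variable names here are illustrative):

import numpy as np
from transformers import CLIPImageProcessor

# A default CLIPImageProcessor ships with do_resize=True, do_center_crop=True,
# and a configured crop_size, so no per-call override is needed.
processor = CLIPImageProcessor()
image = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
pixel_values = processor.preprocess(image, return_tensors="pt").pixel_values
print(pixel_values.shape)  # e.g. torch.Size([1, 3, 224, 224]) with the default crop_size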
@@ -794,13 +791,13 @@ class UniDiffuserPipeline(DiffusionPipeline):
         latent_height = height // self.vae_scale_factor
         latent_width = width // self.vae_scale_factor
         img_vae_dim = self.num_channels_latents * latent_height * latent_width
-        text_dim = self.text_encoder_seq_len * self.text_encoder_hidden_size
+        text_dim = self.text_encoder_seq_len * self.text_intermediate_dim
 
         img_vae, img_clip, text = x.split([img_vae_dim, self.image_encoder_projection_dim, text_dim], dim=1)
 
         img_vae = torch.reshape(img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width))
         img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_projection_dim))
-        text = torch.reshape(text, (batch_size, self.text_encoder_seq_len, self.text_encoder_hidden_size))
+        text = torch.reshape(text, (batch_size, self.text_encoder_seq_len, self.text_intermediate_dim))
         return img_vae, img_clip, text
 
     def _combine_joint(self, img_vae, img_clip, text):
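
This hunk is the core of the fix: x is the flat concatenation [img_vae | img_clip | text], and torch.split requires the section sizes to sum exactly to x.shape[1]. If the text latents are built with text_intermediate_dim columns (e.g. a text decoder prefix dimension) while text_dim is still computed from text_encoder_hidden_size, the sizes no longer add up and the split (or the final reshape) raises a shape error. A self-contained sketch of the corrected arithmetic, with illustrative values that are not taken from the actual model config:

import torch

batch_size = 2
num_channels_latents, latent_height, latent_width = 4, 64, 64  # e.g. 512x512 image, vae_scale_factor=8
image_encoder_projection_dim = 512
text_encoder_seq_len, text_intermediate_dim = 77, 64

img_vae_dim = num_channels_latents * latent_height * latent_width  # 16384
text_dim = text_encoder_seq_len * text_intermediate_dim            # 4928

# Flat joint latent, as _combine_joint would produce it.
x = torch.randn(batch_size, img_vae_dim + image_encoder_projection_dim + text_dim)

img_vae, img_clip, text = x.split([img_vae_dim, image_encoder_projection_dim, text_dim], dim=1)
img_vae = img_vae.reshape(batch_size, num_channels_latents, latent_height, latent_width)
img_clip = img_clip.reshape(batch_size, 1, image_encoder_projection_dim)
text = text.reshape(batch_size, text_encoder_seq_len, text_intermediate_dim)

print(img_vae.shape, img_clip.shape, text.shape)
# torch.Size([2, 4, 64, 64]) torch.Size([2, 1, 512]) torch.Size([2, 77, 64])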
@@ -1268,7 +1265,6 @@ class UniDiffuserPipeline(DiffusionPipeline):
             # Encode image using CLIP
             image_clip_latents = self.encode_image_clip_latents(
                 image=image,
-                resolution=height,  # assume image is square for now...
                 batch_size=batch_size,
                 num_prompts_per_image=multiplier,
                 dtype=prompt_embeds.dtype,
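
This call-site change has to land together with the signature change above: once resolution is removed from encode_image_clip_latents, any caller still passing it fails immediately. A generic illustration with a stub (not the pipeline's real method body):

def encode_image_clip_latents(image, batch_size, num_prompts_per_image, dtype):
    ...

encode_image_clip_latents(
    image=None, resolution=512, batch_size=1, num_prompts_per_image=1, dtype=None
)
# TypeError: encode_image_clip_latents() got an unexpected keyword argument 'resolution'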
@@ -1319,6 +1315,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
             latents = prompt_embeds
 
+        print(f"Initial latents: {latents}")
         print(f"Initial latents shape: {latents.shape}")
 
         # 7. Check that shapes of latents and image match the UNet channels.
         # TODO
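
The print statements here read as temporary debugging aids. If they were to stay, the library's own logger would be the more idiomatic outlet; a sketch using the diffusers logging utilities (the logger name is arbitrary, and `latents` is assumed from the surrounding code):

from diffusers.utils import logging

logger = logging.get_logger(__name__)
logger.debug("Initial latents shape: %s", latents.shape)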