diff --git a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
index 17bebe6e2a..30219ca6c2 100644
--- a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
+++ b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
@@ -2,7 +2,6 @@ import inspect
 from dataclasses import dataclass
 from typing import Callable, List, Optional, Union
 
-import einops
 import numpy as np
 import PIL
 import torch
@@ -693,7 +692,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            latents = einops.repeat(latents, "B L D -> (repeat B) L D", repeat=num_images_per_prompt)
+            # latents is assumed to have shape (B, L, D)
+            latents = latents.repeat(num_images_per_prompt, 1, 1)
             latents = latents.to(device)
 
         # scale the initial noise by the standard deviation required by the scheduler
@@ -729,7 +729,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            latents = einops.repeat(latents, "B C H W -> (repeat B) C H W", repeat=num_prompts_per_image)
+            # latents is assumed to have shape (B, C, H, W)
+            latents = latents.repeat(num_prompts_per_image, 1, 1, 1)
             latents = latents.to(device)
 
         # scale the initial noise by the standard deviation required by the scheduler
@@ -750,7 +751,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            latents = einops.repeat(latents, "B L D -> (repeat B) L D", repeat=num_prompts_per_image)
+            # latents is assumed to have shape (B, L, D)
+            latents = latents.repeat(num_prompts_per_image, 1, 1)
             latents = latents.to(device)
 
         # scale the initial noise by the standard deviation required by the scheduler
@@ -762,16 +764,17 @@ class UniDiffuserPipeline(DiffusionPipeline):
         Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim) into two tensors of shape (B, C, H, W)
         and (B, 1, clip_img_dim)
         """
+        batch_size = x.shape[0]
         latent_height = height // self.vae_scale_factor
         latent_width = width // self.vae_scale_factor
         img_vae_dim = self.num_channels_latents * latent_height * latent_width
 
         img_vae, img_clip = x.split([img_vae_dim, self.image_encoder_hidden_size], dim=1)
 
-        img_vae = einops.rearrange(
-            img_vae, "B (C H W) -> B C H W", C=self.num_channels_latents, H=latent_height, W=latent_width
+        img_vae = torch.reshape(
+            img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width)
         )
-        img_clip = einops.rearrange(img_clip, "B (L D) -> B L D", L=1, D=self.image_encoder_hidden_size)
+        img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_hidden_size))
         return img_vae, img_clip
 
     def _combine(self, img_vae, img_clip):
@@ -779,8 +782,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
         Combines a latent iamge img_vae of shape (B, C, H, W) and a CLIP-embedded image img_clip of shape (B, 1,
         clip_img_dim) into a single tensor of shape (B, C * H * W + clip_img_dim).
         """
-        img_vae = einops.rearrange(img_vae, "B C H W -> B (C H W)")
-        img_clip = einops.rearrange(img_clip, "B L D -> B (L D)")
+        img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1))
+        img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1))
         return torch.concat([img_vae, img_clip], dim=-1)
 
     def _split_joint(self, x, height, width):
@@ -789,6 +792,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
         img_clip, text) where img_vae is of shape (B, C, H, W), img_clip is of shape (B, 1, clip_img_dim), and text is
         of shape (B, text_seq_len, text_dim).
         """
+        batch_size = x.shape[0]
         latent_height = height // self.vae_scale_factor
         latent_width = width // self.vae_scale_factor
         img_vae_dim = self.num_channels_latents * latent_height * latent_width
@@ -796,11 +800,11 @@ class UniDiffuserPipeline(DiffusionPipeline):
 
         img_vae, img_clip, text = x.split([img_vae_dim, self.image_encoder_hidden_size, text_dim], dim=1)
 
-        img_vae = einops.rearrange(
-            img_vae, "B (C H W) -> B C H W", C=self.num_channels_latents, H=latent_height, W=latent_width
+        img_vae = torch.reshape(
+            img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width)
         )
-        img_clip = einops.rearrange(img_clip, "B (L D) -> B L D", L=1, D=self.image_encoder_hidden_size)
-        text = einops.rearrange(text, "B (L D) -> B L D", L=self.text_encoder_seq_len, D=self.text_encoder_hidden_size)
+        img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_hidden_size))
+        text = torch.reshape(text, (batch_size, self.text_encoder_seq_len, self.text_encoder_hidden_size))
         return img_vae, img_clip, text
 
     def _combine_joint(self, img_vae, img_clip, text):
@@ -809,9 +813,9 @@ class UniDiffuserPipeline(DiffusionPipeline):
         clip_img_dim), and a text embedding text of shape (B, L_text, text_dim) into a single embedding x of shape (B,
         C * H * W + L_img * clip_img_dim + L_text * text_dim).
        """
-        img_vae = einops.rearrange(img_vae, "B C H W -> B (C H W)")
-        img_clip = einops.rearrange(img_clip, "B L D -> B (L D)")
-        text = einops.rearrange(text, "B L D -> B (L D)")
+        img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1))
+        img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1))
+        text = torch.reshape(text, (text.shape[0], -1))
        return torch.concat([img_vae, img_clip, text], dim=-1)
 
     def _get_noise_pred(