
Remove dependency on einops by refactoring einops operations to pure torch operations.

Daniel Gu
2023-05-11 03:27:33 -07:00
parent fa9e3879e8
commit 19a20a55b9


@@ -2,7 +2,6 @@ import inspect
from dataclasses import dataclass
from typing import Callable, List, Optional, Union
import einops
import numpy as np
import PIL
import torch
@@ -693,7 +692,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = einops.repeat(latents, "B L D -> (repeat B) L D", repeat=num_images_per_prompt)
# latents is assumed to have shape (B, L, D)
latents = latents.repeat(num_images_per_prompt, 1, 1)
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
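
Note on the repeat replacement (the same reasoning covers the two analogous hunks below): in einops, the pattern "(repeat B)" makes the new repeat axis the outer one, so the whole batch is tiled along dim 0, which is exactly what Tensor.repeat(n, 1, ...) does. Below is a minimal equivalence sketch with dummy shapes, assuming einops is still importable locally for the comparison:

import torch
import einops

# Dummy latents of shape (B, L, D); the real sizes come from the pipeline.
latents = torch.randn(2, 77, 64)
num_images_per_prompt = 3

old = einops.repeat(latents, "B L D -> (repeat B) L D", repeat=num_images_per_prompt)
new = latents.repeat(num_images_per_prompt, 1, 1)  # tiles the whole batch along dim 0
assert torch.equal(old, new)

# Caveat: the pattern "(B repeat)" would instead interleave copies per sample,
# i.e. latents.repeat_interleave(num_images_per_prompt, dim=0), a different layout.
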
@@ -729,7 +729,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = einops.repeat(latents, "B C H W -> (repeat B) C H W", repeat=num_prompts_per_image)
# latents is assumed to have shape (B, C, H, W)
latents = latents.repeat(num_prompts_per_image, 1, 1, 1)
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
@@ -750,7 +751,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = einops.repeat(latents, "B L D -> (repeat B) L D", repeat=num_prompts_per_image)
# latents is assumed to have shape (B, L, D)
latents = latents.repeat(num_prompts_per_image, 1, 1)
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
@@ -762,16 +764,17 @@ class UniDiffuserPipeline(DiffusionPipeline):
Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim) into two tensors of shape (B, C, H, W)
and (B, 1, clip_img_dim)
"""
batch_size = x.shape[0]
latent_height = height // self.vae_scale_factor
latent_width = width // self.vae_scale_factor
img_vae_dim = self.num_channels_latents * latent_height * latent_width
img_vae, img_clip = x.split([img_vae_dim, self.image_encoder_hidden_size], dim=1)
img_vae = einops.rearrange(
img_vae, "B (C H W) -> B C H W", C=self.num_channels_latents, H=latent_height, W=latent_width
img_vae = torch.reshape(
img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width)
)
img_clip = einops.rearrange(img_clip, "B (L D) -> B L D", L=1, D=self.image_encoder_hidden_size)
img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_hidden_size))
return img_vae, img_clip
def _combine(self, img_vae, img_clip):
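
The _split refactor above (and the matching change in _split_joint below) relies on torch.reshape reproducing einops.rearrange for the unflatten patterns "B (C H W) -> B C H W" and "B (L D) -> B L D": both decompose the flattened axis in row-major order, so the results are identical. A minimal check with placeholder sizes standing in for the pipeline's configured dimensions, again assuming einops is available locally:

import torch
import einops

# Placeholder sizes for num_channels_latents, latent height/width and
# image_encoder_hidden_size; the pipeline derives the real values from its config.
B, C, H, W, clip_dim = 2, 4, 8, 8, 512
x = torch.randn(B, C * H * W + clip_dim)
img_vae, img_clip = x.split([C * H * W, clip_dim], dim=1)

# Old einops form vs. new torch form: identical tensors.
assert torch.equal(
    einops.rearrange(img_vae, "B (C H W) -> B C H W", C=C, H=H, W=W),
    torch.reshape(img_vae, (B, C, H, W)),
)
assert torch.equal(
    einops.rearrange(img_clip, "B (L D) -> B L D", L=1, D=clip_dim),
    torch.reshape(img_clip, (B, 1, clip_dim)),
)
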
@@ -779,8 +782,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
Combines a latent image img_vae of shape (B, C, H, W) and a CLIP-embedded image img_clip of shape (B, 1,
clip_img_dim) into a single tensor of shape (B, C * H * W + clip_img_dim).
"""
img_vae = einops.rearrange(img_vae, "B C H W -> B (C H W)")
img_clip = einops.rearrange(img_clip, "B L D -> B (L D)")
img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1))
img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1))
return torch.concat([img_vae, img_clip], dim=-1)
def _split_joint(self, x, height, width):
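
The flatten direction used by _combine (and by _combine_joint below) is the mirror image: "B C H W -> B (C H W)" and "B L D -> B (L D)" are plain row-major flattens, so torch.reshape(t, (t.shape[0], -1)) produces the same tensor. A short sketch with the same placeholder sizes:

import torch
import einops

B, C, H, W, clip_dim = 2, 4, 8, 8, 512
img_vae = torch.randn(B, C, H, W)
img_clip = torch.randn(B, 1, clip_dim)

# Flattening everything after the batch dimension matches the einops patterns.
assert torch.equal(
    einops.rearrange(img_vae, "B C H W -> B (C H W)"),
    torch.reshape(img_vae, (B, -1)),
)
assert torch.equal(
    einops.rearrange(img_clip, "B L D -> B (L D)"),
    torch.reshape(img_clip, (B, -1)),
)
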
@@ -789,6 +792,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
img_clip, text) where img_vae is of shape (B, C, H, W), img_clip is of shape (B, 1, clip_img_dim), and text is
of shape (B, text_seq_len, text_dim).
"""
batch_size = x.shape[0]
latent_height = height // self.vae_scale_factor
latent_width = width // self.vae_scale_factor
img_vae_dim = self.num_channels_latents * latent_height * latent_width
@@ -796,11 +800,11 @@ class UniDiffuserPipeline(DiffusionPipeline):
img_vae, img_clip, text = x.split([img_vae_dim, self.image_encoder_hidden_size, text_dim], dim=1)
img_vae = einops.rearrange(
img_vae, "B (C H W) -> B C H W", C=self.num_channels_latents, H=latent_height, W=latent_width
img_vae = torch.reshape(
img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width)
)
img_clip = einops.rearrange(img_clip, "B (L D) -> B L D", L=1, D=self.image_encoder_hidden_size)
text = einops.rearrange(text, "B (L D) -> B L D", L=self.text_encoder_seq_len, D=self.text_encoder_hidden_size)
img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_hidden_size))
text = torch.reshape(text, (batch_size, self.text_encoder_seq_len, self.text_encoder_hidden_size))
return img_vae, img_clip, text
def _combine_joint(self, img_vae, img_clip, text):
@@ -809,9 +813,9 @@ class UniDiffuserPipeline(DiffusionPipeline):
clip_img_dim), and a text embedding text of shape (B, L_text, text_dim) into a single embedding x of shape (B,
C * H * W + L_img * clip_img_dim + L_text * text_dim).
"""
img_vae = einops.rearrange(img_vae, "B C H W -> B (C H W)")
img_clip = einops.rearrange(img_clip, "B L D -> B (L D)")
text = einops.rearrange(text, "B L D -> B (L D)")
img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1))
img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1))
text = torch.reshape(text, (text.shape[0], -1))
return torch.concat([img_vae, img_clip, text], dim=-1)
def _get_noise_pred(
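
As an overall sanity check, the combine/split pair should be an exact round trip: flattening dummy latents with the torch-only code and splitting them again must recover the inputs bit-for-bit. The sketch below uses hypothetical standalone helpers (combine_joint, split_joint) that mirror the refactored pipeline methods, with placeholder dimensions in place of the pipeline's config values:

import torch

def combine_joint(img_vae, img_clip, text):
    # Mirrors the refactored _combine_joint: flatten each tensor after the
    # batch dimension and concatenate along the last axis.
    B = img_vae.shape[0]
    return torch.concat(
        [img_vae.reshape(B, -1), img_clip.reshape(B, -1), text.reshape(B, -1)], dim=-1
    )

def split_joint(x, C, H, W, clip_dim, text_len, text_dim):
    # Mirrors the refactored _split_joint: split the flat embedding and reshape
    # each piece back to its original layout.
    B = x.shape[0]
    img_vae, img_clip, text = x.split([C * H * W, clip_dim, text_len * text_dim], dim=1)
    return (
        img_vae.reshape(B, C, H, W),
        img_clip.reshape(B, 1, clip_dim),
        text.reshape(B, text_len, text_dim),
    )

# Placeholder dimensions standing in for the pipeline's configured values.
B, C, H, W, clip_dim, text_len, text_dim = 2, 4, 8, 8, 512, 77, 64
img_vae = torch.randn(B, C, H, W)
img_clip = torch.randn(B, 1, clip_dim)
text = torch.randn(B, text_len, text_dim)

x = combine_joint(img_vae, img_clip, text)
vae_back, clip_back, text_back = split_joint(x, C, H, W, clip_dim, text_len, text_dim)
assert torch.equal(vae_back, img_vae)
assert torch.equal(clip_back, img_clip)
assert torch.equal(text_back, text)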