mirror of https://github.com/huggingface/diffusers.git
Remove dependency on einops by refactoring einops operations to pure torch operations.
@@ -2,7 +2,6 @@ import inspect
 from dataclasses import dataclass
 from typing import Callable, List, Optional, Union
 
-import einops
 import numpy as np
 import PIL
 import torch
@@ -693,7 +692,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            latents = einops.repeat(latents, "B L D -> (repeat B) L D", repeat=num_images_per_prompt)
+            # latents is assumed to have shape (B, L, D)
+            latents = latents.repeat(num_images_per_prompt, 1, 1)
             latents = latents.to(device)
 
         # scale the initial noise by the standard deviation required by the scheduler
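For context, the einops pattern `(repeat B)` places the repeat axis outermost, i.e. it tiles whole batches rather than interleaving them, which is exactly what `Tensor.repeat` does along dim 0. A minimal standalone check (toy shapes, not part of the commit):

import torch

# Toy shapes for illustration: B=2 sequences, L=3 tokens, D=4 dims.
latents = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)

# einops.repeat(latents, "B L D -> (repeat B) L D", repeat=2) yields
# [x0, x1, x0, x1]; Tensor.repeat(2, 1, 1) tiles dim 0 the same way
# (repeat_interleave(2, dim=0) would instead give [x0, x0, x1, x1]).
tiled = latents.repeat(2, 1, 1)

assert tiled.shape == (4, 3, 4)
assert torch.equal(tiled[2], latents[0]) and torch.equal(tiled[3], latents[1])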
@@ -729,7 +729,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            latents = einops.repeat(latents, "B C H W -> (repeat B) C H W", repeat=num_prompts_per_image)
+            # latents is assumed to have shape (B, C, H, W)
+            latents = latents.repeat(num_prompts_per_image, 1, 1, 1)
             latents = latents.to(device)
 
         # scale the initial noise by the standard deviation required by the scheduler
@@ -750,7 +751,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            latents = einops.repeat(latents, "B L D -> (repeat B) L D", repeat=num_prompts_per_image)
+            # latents is assumed to have shape (B, L, D)
+            latents = latents.repeat(num_prompts_per_image, 1, 1)
             latents = latents.to(device)
 
         # scale the initial noise by the standard deviation required by the scheduler
@@ -762,16 +764,17 @@ class UniDiffuserPipeline(DiffusionPipeline):
         Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim) into two tensors of shape (B, C, H, W)
         and (B, 1, clip_img_dim)
         """
+        batch_size = x.shape[0]
         latent_height = height // self.vae_scale_factor
         latent_width = width // self.vae_scale_factor
         img_vae_dim = self.num_channels_latents * latent_height * latent_width
 
         img_vae, img_clip = x.split([img_vae_dim, self.image_encoder_hidden_size], dim=1)
 
-        img_vae = einops.rearrange(
-            img_vae, "B (C H W) -> B C H W", C=self.num_channels_latents, H=latent_height, W=latent_width
-        )
-        img_clip = einops.rearrange(img_clip, "B (L D) -> B L D", L=1, D=self.image_encoder_hidden_size)
+        img_vae = torch.reshape(
+            img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width)
+        )
+        img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_hidden_size))
         return img_vae, img_clip
 
     def _combine(self, img_vae, img_clip):
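Both `einops.rearrange` with pattern `B (C H W) -> B C H W` and `torch.reshape` unflatten in row-major order, so the replacement is exact. A quick sketch with assumed toy dimensions (not from the commit):

import torch

B, C, H, W = 2, 4, 8, 8  # toy values; the pipeline derives these from the VAE config

flat = torch.randn(B, C * H * W)

# torch.reshape reads elements in row-major (C-contiguous) order, the same
# order einops.rearrange uses when expanding "B (C H W) -> B C H W".
unflat = torch.reshape(flat, (B, C, H, W))

assert unflat.shape == (B, C, H, W)
assert torch.equal(unflat.reshape(B, -1), flat)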
@@ -779,8 +782,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
         Combines a latent image img_vae of shape (B, C, H, W) and a CLIP-embedded image img_clip of shape (B, 1,
         clip_img_dim) into a single tensor of shape (B, C * H * W + clip_img_dim).
         """
-        img_vae = einops.rearrange(img_vae, "B C H W -> B (C H W)")
-        img_clip = einops.rearrange(img_clip, "B L D -> B (L D)")
+        img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1))
+        img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1))
         return torch.concat([img_vae, img_clip], dim=-1)
 
     def _split_joint(self, x, height, width):
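The flattening direction is the same reshape in reverse, so `_combine` followed by `_split` round-trips. A sketch under assumed toy shapes:

import torch

B, C, H, W, D = 2, 4, 8, 8, 16  # hypothetical latent and CLIP embedding sizes

img_vae = torch.randn(B, C, H, W)
img_clip = torch.randn(B, 1, D)

# rearrange "B C H W -> B (C H W)" is a row-major flatten of the non-batch
# axes, i.e. reshape to (B, -1); concatenation then mirrors _combine.
combined = torch.concat(
    [torch.reshape(img_vae, (B, -1)), torch.reshape(img_clip, (B, -1))], dim=-1
)
assert combined.shape == (B, C * H * W + D)

# Splitting and unflattening recovers the originals, as _split does.
vae_flat, clip_flat = combined.split([C * H * W, D], dim=1)
assert torch.equal(vae_flat.reshape(B, C, H, W), img_vae)
assert torch.equal(clip_flat.reshape(B, 1, D), img_clip)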
@@ -789,6 +792,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
         img_clip, text) where img_vae is of shape (B, C, H, W), img_clip is of shape (B, 1, clip_img_dim), and text is
         of shape (B, text_seq_len, text_dim).
         """
+        batch_size = x.shape[0]
         latent_height = height // self.vae_scale_factor
         latent_width = width // self.vae_scale_factor
         img_vae_dim = self.num_channels_latents * latent_height * latent_width
@@ -796,11 +800,11 @@ class UniDiffuserPipeline(DiffusionPipeline):
 
         img_vae, img_clip, text = x.split([img_vae_dim, self.image_encoder_hidden_size, text_dim], dim=1)
 
-        img_vae = einops.rearrange(
-            img_vae, "B (C H W) -> B C H W", C=self.num_channels_latents, H=latent_height, W=latent_width
-        )
-        img_clip = einops.rearrange(img_clip, "B (L D) -> B L D", L=1, D=self.image_encoder_hidden_size)
-        text = einops.rearrange(text, "B (L D) -> B L D", L=self.text_encoder_seq_len, D=self.text_encoder_hidden_size)
+        img_vae = torch.reshape(
+            img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width)
+        )
+        img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_hidden_size))
+        text = torch.reshape(text, (batch_size, self.text_encoder_seq_len, self.text_encoder_hidden_size))
         return img_vae, img_clip, text
 
     def _combine_joint(self, img_vae, img_clip, text):
@@ -809,9 +813,9 @@ class UniDiffuserPipeline(DiffusionPipeline):
         clip_img_dim), and a text embedding text of shape (B, L_text, text_dim) into a single embedding x of shape (B,
         C * H * W + L_img * clip_img_dim + L_text * text_dim).
         """
-        img_vae = einops.rearrange(img_vae, "B C H W -> B (C H W)")
-        img_clip = einops.rearrange(img_clip, "B L D -> B (L D)")
-        text = einops.rearrange(text, "B L D -> B (L D)")
+        img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1))
+        img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1))
+        text = torch.reshape(text, (text.shape[0], -1))
         return torch.concat([img_vae, img_clip, text], dim=-1)
 
     def _get_noise_pred(
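The joint text branch follows the same rule: `B (L D) -> B L D` with a known sequence length is a plain row-major reshape. A last sketch with hypothetical text dimensions:

import torch

B, L_text, D_text = 2, 77, 32  # assumed toy values for the text embedding

text = torch.randn(B, L_text, D_text)
flat = torch.reshape(text, (B, -1))  # replaces rearrange "B L D -> B (L D)"

# The inverse reshape restores the original, matching
# rearrange "B (L D) -> B L D" with L=L_text, D=D_text.
assert torch.equal(torch.reshape(flat, (B, L_text, D_text)), text)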