mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
remove einops dependency and now inherits from AttentionMixin
This commit is contained in:
@@ -16,13 +16,12 @@ from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
from torch import Tensor, nn
|
||||
from torch.nn.functional import fold, unfold
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin
|
||||
from ..attention_processor import Attention, AttentionProcessor
|
||||
from ..embeddings import get_timestep_embedding
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
@@ -134,6 +133,7 @@ class PhotonAttnProcessor2_0:
|
||||
attn_output = attn.to_out[1](attn_output) # dropout if present
|
||||
|
||||
return attn_output
|
||||
# copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
|
||||
class EmbedND(nn.Module):
|
||||
r"""
|
||||
N-dimensional rotary positional embedding.
|
||||
@@ -155,7 +155,6 @@ class EmbedND(nn.Module):
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
self.rope_rearrange = Rearrange("b n d (i j) -> b n d i j", i=2, j=2)
|
||||
|
||||
def rope(self, pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
assert dim % 2 == 0
|
||||
@@ -163,7 +162,9 @@ class EmbedND(nn.Module):
|
||||
omega = 1.0 / (theta**scale)
|
||||
out = pos.unsqueeze(-1) * omega.unsqueeze(0)
|
||||
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
||||
out = self.rope_rearrange(out)
|
||||
# Native PyTorch equivalent of: Rearrange("b n d (i j) -> b n d i j", i=2, j=2)
|
||||
# out shape: (b, n, d, 4) -> reshape to (b, n, d, 2, 2)
|
||||
out = out.reshape(*out.shape[:-1], 2, 2)
|
||||
return out.float()
|
||||
|
||||
def forward(self, ids: Tensor) -> Tensor:
|
||||
@@ -378,12 +379,20 @@ class PhotonBlock(nn.Module):
|
||||
img_mod = (1 + modulation.scale) * self.img_pre_norm(img) + modulation.shift
|
||||
|
||||
img_qkv = self.img_qkv_proj(img_mod)
|
||||
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
# Native PyTorch equivalent of: rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
B, L, _ = img_qkv.shape
|
||||
img_qkv = img_qkv.reshape(B, L, 3, self.num_heads, self.head_dim) # (B, L, K, H, D)
|
||||
img_qkv = img_qkv.permute(2, 0, 3, 1, 4) # (K, B, H, L, D)
|
||||
img_q, img_k, img_v = img_qkv[0], img_qkv[1], img_qkv[2]
|
||||
img_q, img_k = self.qk_norm(img_q, img_k, img_v)
|
||||
|
||||
# txt tokens proj and norm
|
||||
txt_kv = self.txt_kv_proj(txt)
|
||||
txt_k, txt_v = rearrange(txt_kv, "B L (K H D) -> K B H L D", K=2, H=self.num_heads)
|
||||
# Native PyTorch equivalent of: rearrange(txt_kv, "B L (K H D) -> K B H L D", K=2, H=self.num_heads)
|
||||
B, L, _ = txt_kv.shape
|
||||
txt_kv = txt_kv.reshape(B, L, 2, self.num_heads, self.head_dim) # (B, L, K, H, D)
|
||||
txt_kv = txt_kv.permute(2, 0, 3, 1, 4) # (K, B, H, L, D)
|
||||
txt_k, txt_v = txt_kv[0], txt_kv[1]
|
||||
txt_k = self.k_norm(txt_k)
|
||||
|
||||
# compute attention
|
||||
@@ -564,7 +573,7 @@ def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor:
|
||||
return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
|
||||
class PhotonTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
|
||||
r"""
|
||||
Transformer-based 2D model for text to image generation. It supports attention processor injection and LoRA
|
||||
scaling.
|
||||
@@ -689,65 +698,6 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def _process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]:
|
||||
txt = self.txt_in(txt)
|
||||
img = img2seq(image_latent, self.patch_size)
|
||||
|
||||
Reference in New Issue
Block a user