
remove einops dependency and now inherits from AttentionMixin

davidb
2025-10-10 13:40:49 +00:00
committed by DavidBert
parent 25a0061d65
commit 5886925346


@@ -16,13 +16,12 @@ from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import torch
-from einops import rearrange
-from einops.layers.torch import Rearrange
from torch import Tensor, nn
from torch.nn.functional import fold, unfold

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import AttentionMixin
from ..attention_processor import Attention, AttentionProcessor
from ..embeddings import get_timestep_embedding
from ..modeling_outputs import Transformer2DModelOutput
@@ -134,6 +133,7 @@ class PhotonAttnProcessor2_0:
        attn_output = attn.to_out[1](attn_output)  # dropout if present
        return attn_output


+# copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class EmbedND(nn.Module):
    r"""
    N-dimensional rotary positional embedding.
@@ -155,7 +155,6 @@ class EmbedND(nn.Module):
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim
-        self.rope_rearrange = Rearrange("b n d (i j) -> b n d i j", i=2, j=2)

    def rope(self, pos: Tensor, dim: int, theta: int) -> Tensor:
        assert dim % 2 == 0
@@ -163,7 +162,9 @@ class EmbedND(nn.Module):
        omega = 1.0 / (theta**scale)
        out = pos.unsqueeze(-1) * omega.unsqueeze(0)
        out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
-        out = self.rope_rearrange(out)
+        # Native PyTorch equivalent of: Rearrange("b n d (i j) -> b n d i j", i=2, j=2)
+        # out shape: (b, n, d, 4) -> reshape to (b, n, d, 2, 2)
+        out = out.reshape(*out.shape[:-1], 2, 2)
        return out.float()

    def forward(self, ids: Tensor) -> Tensor:
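For readers checking the patch, the reshape above can be verified against the removed einops layer with a small standalone snippet (not part of the diff; it assumes einops is still installed purely for the comparison):

import torch
from einops.layers.torch import Rearrange

b, n, d = 2, 8, 16
out = torch.randn(b, n, d, 4)  # stands in for the stacked cos/-sin/sin/cos values

reference = Rearrange("b n d (i j) -> b n d i j", i=2, j=2)(out)
native = out.reshape(*out.shape[:-1], 2, 2)

assert torch.equal(reference, native)  # both yield shape (b, n, d, 2, 2)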
@@ -378,12 +379,20 @@ class PhotonBlock(nn.Module):
        img_mod = (1 + modulation.scale) * self.img_pre_norm(img) + modulation.shift
        img_qkv = self.img_qkv_proj(img_mod)
-        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        # Native PyTorch equivalent of: rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        B, L, _ = img_qkv.shape
+        img_qkv = img_qkv.reshape(B, L, 3, self.num_heads, self.head_dim)  # (B, L, K, H, D)
+        img_qkv = img_qkv.permute(2, 0, 3, 1, 4)  # (K, B, H, L, D)
+        img_q, img_k, img_v = img_qkv[0], img_qkv[1], img_qkv[2]
        img_q, img_k = self.qk_norm(img_q, img_k, img_v)

        # txt tokens proj and norm
        txt_kv = self.txt_kv_proj(txt)
-        txt_k, txt_v = rearrange(txt_kv, "B L (K H D) -> K B H L D", K=2, H=self.num_heads)
+        # Native PyTorch equivalent of: rearrange(txt_kv, "B L (K H D) -> K B H L D", K=2, H=self.num_heads)
+        B, L, _ = txt_kv.shape
+        txt_kv = txt_kv.reshape(B, L, 2, self.num_heads, self.head_dim)  # (B, L, K, H, D)
+        txt_kv = txt_kv.permute(2, 0, 3, 1, 4)  # (K, B, H, L, D)
+        txt_k, txt_v = txt_kv[0], txt_kv[1]
        txt_k = self.k_norm(txt_k)

        # compute attention
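The same equivalence holds for the qkv projections above; here is a short self-contained check (illustrative only, assuming einops is available to compute the reference; unbind is simply an alternative to indexing the first axis):

import torch
from einops import rearrange

B, L, H, D, K = 2, 10, 4, 8, 3
qkv = torch.randn(B, L, K * H * D)

reference = rearrange(qkv, "B L (K H D) -> K B H L D", K=K, H=H)
native = qkv.reshape(B, L, K, H, D).permute(2, 0, 3, 1, 4)  # split last dim, then reorder axes

assert torch.equal(reference, native)
q, k, v = native.unbind(0)  # equivalent to native[0], native[1], native[2]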
@@ -564,7 +573,7 @@ def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor:
    return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)


-class PhotonTransformer2DModel(ModelMixin, ConfigMixin):
+class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
    r"""
    Transformer-based 2D model for text to image generation. It supports attention processor injection and LoRA
    scaling.
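Because the model now inherits from AttentionMixin, the attn_processors property and set_attn_processor method deleted in the next hunk are expected to come from the mixin instead. A minimal sketch of the mixin mechanics on a toy module (assuming AttentionMixin works on any nn.Module via named_children, the same way the deleted code did; the toy class and its sizes are made up for illustration):

import torch.nn as nn
from diffusers.models.attention import AttentionMixin
from diffusers.models.attention_processor import Attention, AttnProcessor2_0


class TinyAttnModel(nn.Module, AttentionMixin):
    def __init__(self, dim: int = 32, heads: int = 4):
        super().__init__()
        self.attn = Attention(query_dim=dim, heads=heads, dim_head=dim // heads)


model = TinyAttnModel()
print(model.attn_processors)                  # e.g. {"attn.processor": <AttnProcessor2_0 ...>}
model.set_attn_processor(AttnProcessor2_0())  # one processor instance for all Attention layers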
@@ -689,65 +698,6 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin):
        self.gradient_checkpointing = False

-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)

    def _process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]:
        txt = self.txt_in(txt)
        img = img2seq(image_latent, self.patch_size)
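The patchify helpers used here and in the seq2img context above pair unfold with fold. A round-trip sketch (the img2seq body below is an inference from the seq2img line shown in the diff, not copied from the file):

import torch
from torch.nn.functional import fold, unfold


def img2seq(image_latent: torch.Tensor, patch_size: int) -> torch.Tensor:
    # (B, C, H, W) -> (B, L, C * patch_size**2): one flattened patch per token
    return unfold(image_latent, kernel_size=patch_size, stride=patch_size).transpose(1, 2)


def seq2img(seq: torch.Tensor, patch_size: int, shape) -> torch.Tensor:
    # inverse mapping, matching the fold call shown in the diff
    return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)


x = torch.randn(1, 16, 32, 32)
assert torch.equal(seq2img(img2seq(x, 2), 2, x.shape[-2:]), x)  # exact: patches do not overlap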