diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py
index fc7a89c8d6..8c7180294c 100644
--- a/src/diffusers/models/transformers/transformer_photon.py
+++ b/src/diffusers/models/transformers/transformer_photon.py
@@ -16,13 +16,12 @@ from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
-from einops import rearrange
-from einops.layers.torch import Rearrange
 from torch import Tensor, nn
 from torch.nn.functional import fold, unfold
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import AttentionMixin
 from ..attention_processor import Attention, AttentionProcessor
 from ..embeddings import get_timestep_embedding
 from ..modeling_outputs import Transformer2DModelOutput
@@ -134,6 +133,7 @@ class PhotonAttnProcessor2_0:
         attn_output = attn.to_out[1](attn_output)  # dropout if present
         return attn_output
 
+# copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
 class EmbedND(nn.Module):
     r"""
     N-dimensional rotary positional embedding.
@@ -155,7 +155,6 @@ class EmbedND(nn.Module):
         self.dim = dim
         self.theta = theta
         self.axes_dim = axes_dim
-        self.rope_rearrange = Rearrange("b n d (i j) -> b n d i j", i=2, j=2)
 
     def rope(self, pos: Tensor, dim: int, theta: int) -> Tensor:
         assert dim % 2 == 0
@@ -163,7 +162,9 @@ class EmbedND(nn.Module):
         omega = 1.0 / (theta**scale)
         out = pos.unsqueeze(-1) * omega.unsqueeze(0)
         out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
-        out = self.rope_rearrange(out)
+        # Native PyTorch equivalent of: Rearrange("b n d (i j) -> b n d i j", i=2, j=2)
+        # out shape: (b, n, d, 4) -> reshape to (b, n, d, 2, 2)
+        out = out.reshape(*out.shape[:-1], 2, 2)
         return out.float()
 
     def forward(self, ids: Tensor) -> Tensor:
@@ -378,12 +379,20 @@ class PhotonBlock(nn.Module):
         img_mod = (1 + modulation.scale) * self.img_pre_norm(img) + modulation.shift
 
         img_qkv = self.img_qkv_proj(img_mod)
-        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        # Native PyTorch equivalent of: rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        B, L, _ = img_qkv.shape
+        img_qkv = img_qkv.reshape(B, L, 3, self.num_heads, self.head_dim)  # (B, L, K, H, D)
+        img_qkv = img_qkv.permute(2, 0, 3, 1, 4)  # (K, B, H, L, D)
+        img_q, img_k, img_v = img_qkv[0], img_qkv[1], img_qkv[2]
         img_q, img_k = self.qk_norm(img_q, img_k, img_v)
 
         # txt tokens proj and norm
         txt_kv = self.txt_kv_proj(txt)
-        txt_k, txt_v = rearrange(txt_kv, "B L (K H D) -> K B H L D", K=2, H=self.num_heads)
+        # Native PyTorch equivalent of: rearrange(txt_kv, "B L (K H D) -> K B H L D", K=2, H=self.num_heads)
+        B, L, _ = txt_kv.shape
+        txt_kv = txt_kv.reshape(B, L, 2, self.num_heads, self.head_dim)  # (B, L, K, H, D)
+        txt_kv = txt_kv.permute(2, 0, 3, 1, 4)  # (K, B, H, L, D)
+        txt_k, txt_v = txt_kv[0], txt_kv[1]
         txt_k = self.k_norm(txt_k)
 
         # compute attention
@@ -564,7 +573,7 @@ def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor:
     return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
 
-class PhotonTransformer2DModel(ModelMixin, ConfigMixin):
+class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
     r"""
     Transformer-based 2D model for text to image generation.
 
     It supports attention processor injection and LoRA scaling.
@@ -689,65 +698,6 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin):
 
         self.gradient_checkpointing = False
 
-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
     def _process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]:
         txt = self.txt_in(txt)
         img = img2seq(image_latent, self.patch_size)
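
Reviewer note, not part of the patch: a minimal, self-contained sketch to sanity-check that the reshape/permute rewrites above reproduce the removed einops calls. The shapes B, L, H, D are arbitrary placeholders, D stands in for head_dim, and einops is only needed locally for the comparison since the library itself no longer imports it.

import torch
from einops import rearrange

B, L, H, D = 2, 16, 4, 8  # hypothetical sizes, not taken from the model config

# QKV split: rearrange("B L (K H D) -> K B H L D", K=3) vs. reshape + permute
qkv = torch.randn(B, L, 3 * H * D)
expected = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=H)
native = qkv.reshape(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
assert torch.equal(expected, native)

# RoPE output: Rearrange("b n d (i j) -> b n d i j", i=2, j=2) vs. reshape
rope = torch.randn(B, L, D, 4)
expected_rope = rearrange(rope, "b n d (i j) -> b n d i j", i=2, j=2)
native_rope = rope.reshape(*rope.shape[:-1], 2, 2)
assert torch.equal(expected_rope, native_rope)

print("einops and native PyTorch paths match")

Both checks rest on einops decomposing a flat axis in the same row-major order that Tensor.reshape uses, so reshape followed by permute reproduces each rearrange exactly.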
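
A second sketch, assuming AttentionMixin (imported in the diff from diffusers.models.attention) defines attn_processors and set_attn_processor, which is what dropping the copied-from-UNet methods relies on: it checks that the model now inherits the processor plumbing instead of re-implementing it.

import inspect

from diffusers.models.attention import AttentionMixin
from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel

# The model should now pick up the processor helpers from the mixin.
assert issubclass(PhotonTransformer2DModel, AttentionMixin)

for name in ("attn_processors", "set_attn_processor"):
    # No longer defined on the class itself ...
    assert name not in vars(PhotonTransformer2DModel), f"{name} is still overridden"
    # ... and resolves to the mixin's attribute through the MRO.
    assert inspect.getattr_static(PhotonTransformer2DModel, name) is inspect.getattr_static(AttentionMixin, name)

print("attn_processors / set_attn_processor are provided by AttentionMixin")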