diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py
index 0f6dd677ac..bf4df9df93 100644
--- a/src/diffusers/models/transformers/transformer_chroma.py
+++ b/src/diffusers/models/transformers/transformer_chroma.py
@@ -24,19 +24,13 @@ from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, Pe
 from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.import_utils import is_torch_npu_available
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import (
-    Attention,
-    AttentionProcessor,
-    FluxAttnProcessor2_0,
-    FluxAttnProcessor2_0_NPU,
-    FusedFluxAttnProcessor2_0,
-)
+from ..attention import AttentionMixin, FeedForward
 from ..cache_utils import CacheMixin
 from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm
+from .transformer_flux import FluxAttention, FluxAttnProcessor
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -223,6 +217,8 @@ class ChromaSingleTransformerBlock(nn.Module):
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
         if is_torch_npu_available():
+            from ..attention_processor import FluxAttnProcessor2_0_NPU
+
             deprecation_message = (
                 "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
                 "should be set explicitly using the `set_attn_processor` method."
@@ -230,11 +226,10 @@ class ChromaSingleTransformerBlock(nn.Module):
             deprecate("npu_processor", "0.34.0", deprecation_message)
             processor = FluxAttnProcessor2_0_NPU()
         else:
-            processor = FluxAttnProcessor2_0()
+            processor = FluxAttnProcessor()
 
-        self.attn = Attention(
+        self.attn = FluxAttention(
             query_dim=dim,
-            cross_attention_dim=None,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             out_dim=dim,
@@ -292,16 +287,15 @@ class ChromaTransformerBlock(nn.Module):
         self.norm1 = ChromaAdaLayerNormZeroPruned(dim)
         self.norm1_context = ChromaAdaLayerNormZeroPruned(dim)
 
-        self.attn = Attention(
+        self.attn = FluxAttention(
             query_dim=dim,
-            cross_attention_dim=None,
             added_kv_proj_dim=dim,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             out_dim=dim,
             context_pre_only=False,
             bias=True,
-            processor=FluxAttnProcessor2_0(),
+            processor=FluxAttnProcessor(),
             qk_norm=qk_norm,
             eps=eps,
         )
@@ -376,7 +370,13 @@
 
 
 class ChromaTransformer2DModel(
-    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
+    ModelMixin,
+    ConfigMixin,
+    PeftAdapterMixin,
+    FromOriginalModelMixin,
+    FluxTransformer2DLoadersMixin,
+    CacheMixin,
+    AttentionMixin,
 ):
     """
     The Transformer model introduced in Flux, modified for Chroma.
@@ -475,106 +475,6 @@ class ChromaTransformer2DModel(
 
         self.gradient_checkpointing = False
 
-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
-    def fuse_qkv_projections(self):
-        """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
-        are fused. For cross-attention modules, key and value projection matrices are fused.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-        """
-        self.original_attn_processors = None
-
-        for _, attn_processor in self.attn_processors.items():
-            if "Added" in str(attn_processor.__class__.__name__):
-                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
-        self.original_attn_processors = self.attn_processors
-
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
-
-        self.set_attn_processor(FusedFluxAttnProcessor2_0())
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
-    def unfuse_qkv_projections(self):
-        """Disables the fused QKV projection if enabled.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-
-        """
-        if self.original_attn_processors is not None:
-            self.set_attn_processor(self.original_attn_processors)
-
     def forward(
         self,
         hidden_states: torch.Tensor,