diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py
new file mode 100644
index 0000000000..c542bcaacc
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_chroma.py
@@ -0,0 +1,636 @@
+# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.import_utils import is_torch_npu_available
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import FeedForward
+from ..attention_processor import (
+ Attention,
+ AttentionProcessor,
+ FluxAttnProcessor2_0,
+ FluxAttnProcessor2_0_NPU,
+ FusedFluxAttnProcessor2_0,
+)
+from ..cache_utils import CacheMixin
+from ..embeddings import (
+ CombinedTimestepGuidanceTextProjEmbeddings,
+ CombinedTimestepTextProjChromaEmbeddings,
+ CombinedTimestepTextProjEmbeddings,
+ ChromaApproximator,
+ FluxPosEmbed,
+)
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import (
+ AdaLayerNormContinuous,
+ AdaLayerNormContinuousPruned,
+ AdaLayerNormZero,
+ AdaLayerNormZeroPruned,
+ AdaLayerNormZeroSingle,
+ AdaLayerNormZeroSinglePruned,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+INVALID_VARIANT_ERRMSG = "`variant` must be `'flux' or `'chroma'`."
+
+
+@maybe_allow_in_graph
+class FluxSingleTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_ratio: float = 4.0,
+ variant: str = "flux",
+ ):
+ super().__init__()
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
+
+ if variant == "flux":
+ self.norm = AdaLayerNormZeroSingle(dim)
+ elif variant == "chroma":
+ self.norm = AdaLayerNormZeroSinglePruned(dim)
+ else:
+ raise ValueError(INVALID_VARIANT_ERRMSG)
+
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
+ self.act_mlp = nn.GELU(approximate="tanh")
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
+
+ if is_torch_npu_available():
+ deprecation_message = (
+ "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
+ "should be set explicitly using the `set_attn_processor` method."
+ )
+ deprecate("npu_processor", "0.34.0", deprecation_message)
+ processor = FluxAttnProcessor2_0_NPU()
+ else:
+ processor = FluxAttnProcessor2_0()
+
+ self.attn = Attention(
+ query_dim=dim,
+ cross_attention_dim=None,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ bias=True,
+ processor=processor,
+ qk_norm="rms_norm",
+ eps=1e-6,
+ pre_only=True,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+ joint_attention_kwargs = joint_attention_kwargs or {}
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ gate = gate.unsqueeze(1)
+ hidden_states = gate * self.proj_out(hidden_states)
+ hidden_states = residual + hidden_states
+ if hidden_states.dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ return hidden_states
+
+
+@maybe_allow_in_graph
+class FluxTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ qk_norm: str = "rms_norm",
+ eps: float = 1e-6,
+ variant: str = "flux",
+ ):
+ super().__init__()
+
+ if variant == "flux":
+ self.norm1 = AdaLayerNormZero(dim)
+ self.norm1_context = AdaLayerNormZero(dim)
+ elif variant == "chroma":
+ self.norm1 = AdaLayerNormZeroPruned(dim)
+ self.norm1_context = AdaLayerNormZeroPruned(dim)
+ else:
+ raise ValueError(INVALID_VARIANT_ERRMSG)
+
+ self.attn = Attention(
+ query_dim=dim,
+ cross_attention_dim=None,
+ added_kv_proj_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ context_pre_only=False,
+ bias=True,
+ processor=FluxAttnProcessor2_0(),
+ qk_norm=qk_norm,
+ eps=eps,
+ )
+
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ temb_img, temb_txt = temb[:, :6], temb[:, 6:]
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb_img)
+
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ encoder_hidden_states, emb=temb_txt
+ )
+ joint_attention_kwargs = joint_attention_kwargs or {}
+ # Attention.
+ attention_outputs = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+
+ if len(attention_outputs) == 2:
+ attn_output, context_attn_output = attention_outputs
+ elif len(attention_outputs) == 3:
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
+
+ # Process attention outputs for the `hidden_states`.
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = hidden_states + attn_output
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = hidden_states + ff_output
+ if len(attention_outputs) == 3:
+ hidden_states = hidden_states + ip_attn_output
+
+ # Process attention outputs for the `encoder_hidden_states`.
+
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+ if encoder_hidden_states.dtype == torch.float16:
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+ return encoder_hidden_states, hidden_states
+
+
+class FluxTransformer2DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
+):
+ """
+ The Transformer model introduced in Flux.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Args:
+ patch_size (`int`, defaults to `1`):
+ Patch size to turn the input data into small patches.
+ in_channels (`int`, defaults to `64`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `None`):
+ The number of channels in the output. If not specified, it defaults to `in_channels`.
+ num_layers (`int`, defaults to `19`):
+ The number of layers of dual stream DiT blocks to use.
+ num_single_layers (`int`, defaults to `38`):
+ The number of layers of single stream DiT blocks to use.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of dimensions to use for each attention head.
+ num_attention_heads (`int`, defaults to `24`):
+ The number of attention heads to use.
+ joint_attention_dim (`int`, defaults to `4096`):
+ The number of dimensions to use for the joint attention (embedding/channel dimension of
+ `encoder_hidden_states`).
+ pooled_projection_dim (`int`, defaults to `768`):
+ The number of dimensions to use for the pooled projection.
+ guidance_embeds (`bool`, defaults to `False`):
+ Whether to use guidance embeddings for guidance-distilled variant of the model.
+ axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
+ The dimensions to use for the rotary positional embeddings.
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 1,
+ in_channels: int = 64,
+ out_channels: Optional[int] = None,
+ num_layers: int = 19,
+ num_single_layers: int = 38,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 24,
+ joint_attention_dim: int = 4096,
+ pooled_projection_dim: int = 768,
+ guidance_embeds: bool = False,
+ axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
+ variant: str = "flux",
+ approximator_in_factor: int = 16,
+ approximator_hidden_dim: int = 5120,
+ approximator_layers: int = 5,
+ ):
+ super().__init__()
+ self.out_channels = out_channels or in_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
+
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
+
+ if variant == "flux":
+ text_time_guidance_cls = (
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
+ )
+ self.time_text_embed = text_time_guidance_cls(
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
+ )
+ elif variant == "chroma":
+ self.time_text_embed = CombinedTimestepTextProjChromaEmbeddings(
+ factor=approximator_in_factor,
+ hidden_dim=approximator_hidden_dim,
+ out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2,
+ embedding_dim=self.inner_dim,
+ n_layers=approximator_layers,
+ )
+ self.distilled_guidance_layer = ChromaApproximator(in_dim=64, out_dim=3072, hidden_dim=5120, n_layers=5)
+ else:
+ raise ValueError(INVALID_VARIANT_ERRMSG)
+
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ FluxTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ variant=variant,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ self.single_transformer_blocks = nn.ModuleList(
+ [
+ FluxSingleTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ variant=variant,
+ )
+ for _ in range(num_single_layers)
+ ]
+ )
+
+ norm_out_cls = AdaLayerNormContinuous if variant != "chroma" else AdaLayerNormContinuousPruned
+ self.norm_out = norm_out_cls(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+ self.gradient_checkpointing = False
+
+ @property
+ def is_chroma(self) -> bool:
+ return isinstance(self.time_text_embed, CombinedTimestepTextProjChromaEmbeddings)
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+
+
+ This API is 🧪 experimental.
+
+
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedFluxAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ pooled_projections: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_block_samples=None,
+ controlnet_single_block_samples=None,
+ return_dict: bool = True,
+ controlnet_blocks_repeat: bool = False,
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ """
+ The [`FluxTransformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ from the embeddings of input conditions.
+ timestep ( `torch.LongTensor`):
+ Used to indicate denoising step.
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+ A list of tensors that if specified are added to the residuals of transformer blocks.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ if joint_attention_kwargs is not None:
+ joint_attention_kwargs = joint_attention_kwargs.copy()
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ is_chroma = self.is_chroma
+ hidden_states = self.x_embedder(hidden_states)
+
+ timestep = timestep.to(hidden_states.dtype) * 1000
+ if guidance is not None:
+ guidance = guidance.to(hidden_states.dtype) * 1000
+
+ if not is_chroma:
+ temb = (
+ self.time_text_embed(timestep, pooled_projections)
+ if guidance is None
+ else self.time_text_embed(timestep, guidance, pooled_projections)
+ )
+ else:
+ input_vec = self.time_text_embed(timestep, guidance, pooled_projections)
+ pooled_temb = self.distilled_guidance_layer(input_vec)
+
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ if txt_ids.ndim == 3:
+ logger.warning(
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ txt_ids = txt_ids[0]
+ if img_ids.ndim == 3:
+ logger.warning(
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ img_ids = img_ids[0]
+
+ ids = torch.cat((txt_ids, img_ids), dim=0)
+ image_rotary_emb = self.pos_embed(ids)
+
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
+ for index_block, block in enumerate(self.transformer_blocks):
+ if is_chroma:
+ img_offset = 3 * len(self.single_transformer_blocks)
+ txt_offset = img_offset + 6 * len(self.transformer_blocks)
+ img_modulation = img_offset + 6 * index_block
+ text_modulation = txt_offset + 6 * index_block
+ temb = torch.cat(
+ (
+ pooled_temb[:, img_modulation : img_modulation + 6],
+ pooled_temb[:, text_modulation : text_modulation + 6],
+ ),
+ dim=1,
+ )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+
+ # controlnet residual
+ if controlnet_block_samples is not None:
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ # For Xlabs ControlNet.
+ if controlnet_blocks_repeat:
+ hidden_states = (
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
+ )
+ else:
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ for index_block, block in enumerate(self.single_transformer_blocks):
+ if is_chroma:
+ start_idx = 3 * index_block
+ temb = pooled_temb[:, start_idx : start_idx + 3]
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ temb,
+ image_rotary_emb,
+ )
+
+ else:
+ hidden_states = block(
+ hidden_states=hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+
+ # controlnet residual
+ if controlnet_single_block_samples is not None:
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ + controlnet_single_block_samples[index_block // interval_control]
+ )
+
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+
+ if is_chroma:
+ temb = pooled_temb[:, -2:]
+ hidden_states = self.norm_out(hidden_states, temb)
+ output = self.proj_out(hidden_states)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)