1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Bria 3 2 pipeline (#12010)

* Add Bria model and pipeline to diffusers

- Introduced `BriaTransformer2DModel` and `BriaPipeline` for enhanced image generation capabilities.
- Updated import structures across various modules to include the new Bria components.
- Added utility functions and output classes specific to the Bria pipeline.
- Implemented tests for the Bria pipeline to ensure functionality and output integrity.

* with working tests

* style and quality pass

* adding docs

* add to overview

* fixes from "make fix-copies"

* Refactor transformer_bria.py and pipeline_bria.py: Introduce new EmbedND class for rotary position embedding, and enhance Timestep and TimestepProjEmbeddings classes. Add utility functions for handling negative prompts and generating original sigmas in pipeline_bria.py.

* remove redundent and duplicates tests and fix bf16
slow test

* style fixes

* small doc update

* Enhance Bria 3.2 documentation and implementation

- Updated the GitHub repository link for Bria 3.2.
- Added usage instructions for the gated model access.
- Introduced the BriaTransformerBlock and BriaAttention classes to the model architecture.
- Refactored existing classes to integrate Bria-specific components, including BriaEmbedND and BriaPipeline.
- Updated the pipeline output class to reflect Bria-specific functionality.
- Adjusted test cases to align with the new Bria model structure.

* Refactor Bria model components and update documentation

- Removed outdated inference example from Bria 3.2 documentation.
- Introduced the BriaTransformerBlock class to enhance model architecture.
- Updated attention handling to use `attention_kwargs` instead of `joint_attention_kwargs`.
- Improved import structure in the Bria pipeline to handle optional dependencies.
- Adjusted test cases to reflect changes in model dtype assertions.

* Update Bria model reference in documentation to reflect new file naming convention

* Update docs/source/en/_toctree.yml

* Refactor BriaPipeline to inherit from DiffusionPipeline instead of FluxPipeline, updating imports accordingly.

* move the __call__ func to the end of file

* Update BriaPipeline example to use bfloat16 for precision sensitivity for better result

* make style && make quality &&  make fix-copiessource

---------

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
Co-authored-by: Aryan <contact.aryanvs@gmail.com>
This commit is contained in:
galbria
2025-08-20 12:27:39 +03:00
committed by GitHub
parent 7a2b78bf0f
commit 7993be9e7f
18 changed files with 2131 additions and 0 deletions

View File

@@ -340,6 +340,8 @@
title: AllegroTransformer3DModel
- local: api/models/aura_flow_transformer2d
title: AuraFlowTransformer2DModel
- local: api/models/bria_transformer
title: BriaTransformer2DModel
- local: api/models/chroma_transformer
title: ChromaTransformer2DModel
- local: api/models/cogvideox_transformer3d
@@ -468,6 +470,8 @@
title: AutoPipeline
- local: api/pipelines/blip_diffusion
title: BLIP-Diffusion
- local: api/pipelines/bria_3_2
title: Bria 3.2
- local: api/pipelines/chroma
title: Chroma
- local: api/pipelines/cogvideox

View File

@@ -0,0 +1,19 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# BriaTransformer2DModel
A modified flux Transformer model from [Bria](https://huggingface.co/briaai/BRIA-3.2)
## BriaTransformer2DModel
[[autodoc]] BriaTransformer2DModel

View File

@@ -0,0 +1,44 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Bria 3.2
Bria 3.2 is the next-generation commercial-ready text-to-image model. With just 4 billion parameters, it provides exceptional aesthetics and text rendering, evaluated to provide on par results to leading open-source models, and outperforming other licensed models.
In addition to being built entirely on licensed data, 3.2 provides several advantages for enterprise and commercial use:
- Efficient Compute - the model is X3 smaller than the equivalent models in the market (4B parameters vs 12B parameters other open source models)
- Architecture Consistency: Same architecture as 3.1—ideal for users looking to upgrade without disruption.
- Fine-tuning Speedup: 2x faster fine-tuning on L40S and A100.
Original model checkpoints for Bria 3.2 can be found [here](https://huggingface.co/briaai/BRIA-3.2).
Github repo for Bria 3.2 can be found [here](https://github.com/Bria-AI/BRIA-3.2).
If you want to learn more about the Bria platform, and get free traril access, please visit [bria.ai](https://bria.ai).
## Usage
_As the model is gated, before using it with diffusers you first need to go to the [Bria 3.2 Hugging Face page](https://huggingface.co/briaai/BRIA-3.2), fill in the form and accept the gate. Once you are in, you need to login so that your system knows youve accepted the gate._
Use the command below to log in:
```bash
hf auth login
```
## BriaPipeline
[[autodoc]] BriaPipeline
- all
- __call__

View File

@@ -37,6 +37,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [AudioLDM2](audioldm2) | text2audio |
| [AuraFlow](auraflow) | text2image |
| [BLIP Diffusion](blip_diffusion) | text2image |
| [Bria 3.2](bria_3_2) | text2image |
| [CogVideoX](cogvideox) | text2video |
| [Consistency Models](consistency_models) | unconditional image generation |
| [ControlNet](controlnet) | text2image, image2image, inpainting |

View File

@@ -181,6 +181,7 @@ else:
"AutoencoderOobleck",
"AutoencoderTiny",
"AutoModel",
"BriaTransformer2DModel",
"CacheMixin",
"ChromaTransformer2DModel",
"CogVideoXTransformer3DModel",
@@ -397,6 +398,7 @@ else:
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"BriaPipeline",
"ChromaImg2ImgPipeline",
"ChromaPipeline",
"CLIPImageProjection",
@@ -846,6 +848,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderOobleck,
AutoencoderTiny,
AutoModel,
BriaTransformer2DModel,
CacheMixin,
ChromaTransformer2DModel,
CogVideoXTransformer3DModel,
@@ -1032,6 +1035,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
AuraFlowPipeline,
BriaPipeline,
ChromaImg2ImgPipeline,
ChromaPipeline,
CLIPImageProjection,

View File

@@ -144,6 +144,7 @@ def _register_attention_processors_metadata():
def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
from ..models.transformers.transformer_bria import BriaTransformerBlock
from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from ..models.transformers.transformer_hunyuan_video import (
@@ -165,6 +166,13 @@ def _register_transformer_blocks_metadata():
return_encoder_hidden_states_index=None,
),
)
TransformerBlockRegistry.register(
model_class=BriaTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# CogVideoX
TransformerBlockRegistry.register(

View File

@@ -76,6 +76,7 @@ if is_torch_available():
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
@@ -158,6 +159,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .transformers import (
AllegroTransformer3DModel,
AuraFlowTransformer2DModel,
BriaTransformer2DModel,
ChromaTransformer2DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,

View File

@@ -17,6 +17,7 @@ if is_torch_available():
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
from .transformer_allegro import AllegroTransformer3DModel
from .transformer_bria import BriaTransformer2DModel
from .transformer_chroma import ChromaTransformer2DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_cogview4 import CogView4Transformer2DModel

View File

@@ -0,0 +1,719 @@
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import TimestepEmbedding, apply_rotary_emb, get_timestep_embedding
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _get_projections(attn: "BriaAttention", hidden_states, encoder_hidden_states=None):
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
encoder_query = encoder_key = encoder_value = None
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
return query, key, value, encoder_query, encoder_key, encoder_value
def _get_fused_projections(attn: "BriaAttention", hidden_states, encoder_hidden_states=None):
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
encoder_query = encoder_key = encoder_value = (None,)
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
return query, key, value, encoder_query, encoder_key, encoder_value
def _get_qkv_projections(attn: "BriaAttention", hidden_states, encoder_hidden_states=None):
if attn.fused_projections:
return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
return _get_projections(attn, hidden_states, encoder_hidden_states)
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[np.ndarray, int],
theta: float = 10000.0,
use_real=False,
linear_factor=1.0,
ntk_factor=1.0,
repeat_interleave_real=True,
freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
data type.
Args:
dim (`int`): Dimension of the frequency tensor.
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
theta (`float`, *optional*, defaults to 10000.0):
Scaling factor for frequency computation. Defaults to 10000.0.
use_real (`bool`, *optional*):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
linear_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the context extrapolation. Defaults to 1.0.
ntk_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
Otherwise, they are concateanted with themselves.
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
the dtype of the frequency tensor.
Returns:
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
assert dim % 2 == 0
if isinstance(pos, int):
pos = torch.arange(pos)
if isinstance(pos, np.ndarray):
pos = torch.from_numpy(pos) # type: ignore # [S]
theta = theta * ntk_factor
freqs = (
1.0
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
/ linear_factor
) # [D/2]
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
if use_real and repeat_interleave_real:
# bria
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
return freqs_cos, freqs_sin
elif use_real:
# stable audio, allegro
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
return freqs_cos, freqs_sin
else:
# lumina
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
class BriaAttnProcessor:
_attention_backend = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
def __call__(
self,
attn: "BriaAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
attn, hidden_states, encoder_hidden_states
)
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
query = attn.norm_q(query)
key = attn.norm_k(key)
if attn.added_kv_proj_dim is not None:
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
encoder_query = attn.norm_added_q(encoder_query)
encoder_key = attn.norm_added_k(encoder_key)
query = torch.cat([encoder_query, query], dim=1)
key = torch.cat([encoder_key, key], dim=1)
value = torch.cat([encoder_value, value], dim=1)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
hidden_states = dispatch_attention_fn(
query, key, value, attn_mask=attention_mask, backend=self._attention_backend
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
class BriaAttention(torch.nn.Module, AttentionModuleMixin):
_default_processor_cls = BriaAttnProcessor
_available_processors = [
BriaAttnProcessor,
]
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
out_bias: bool = True,
eps: float = 1e-5,
out_dim: int = None,
context_pre_only: Optional[bool] = None,
pre_only: bool = False,
elementwise_affine: bool = True,
processor=None,
):
super().__init__()
self.head_dim = dim_head
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.use_bias = bias
self.dropout = dropout
self.out_dim = out_dim if out_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.heads = out_dim // dim_head if out_dim is not None else heads
self.added_kv_proj_dim = added_kv_proj_dim
self.added_proj_bias = added_proj_bias
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
if not self.pre_only:
self.to_out = torch.nn.ModuleList([])
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(torch.nn.Dropout(dropout))
if added_kv_proj_dim is not None:
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
if len(unused_kwargs) > 0:
logger.warning(
f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
class BriaEmbedND(torch.nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
class BriaTimesteps(nn.Module):
def __init__(
self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
self.time_theta = time_theta
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
max_period=self.time_theta,
)
return t_emb
class BriaTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, time_theta):
super().__init__()
self.time_proj = BriaTimesteps(
num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timestep, dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D)
return timesteps_emb
class BriaPosEmbed(torch.nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
@maybe_allow_in_graph
class BriaTransformerBlock(nn.Module):
def __init__(
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
):
super().__init__()
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
self.attn = BriaAttention(
query_dim=dim,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
processor=BriaAttnProcessor(),
eps=eps,
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
attention_kwargs = attention_kwargs or {}
# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**attention_kwargs,
)
if len(attention_outputs) == 2:
attn_output, context_attn_output = attention_outputs
elif len(attention_outputs) == 3:
attn_output, context_attn_output, ip_attn_output = attention_outputs
# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
if len(attention_outputs) == 3:
hidden_states = hidden_states + ip_attn_output
# Process attention outputs for the `encoder_hidden_states`.
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return encoder_hidden_states, hidden_states
@maybe_allow_in_graph
class BriaSingleTransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)
self.norm = AdaLayerNormZeroSingle(dim)
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
processor = BriaAttnProcessor()
self.attn = BriaAttention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
eps=1e-6,
pre_only=True,
)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
text_seq_len = encoder_hidden_states.shape[1]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
attention_kwargs = attention_kwargs or {}
attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
**attention_kwargs,
)
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
gate = gate.unsqueeze(1)
hidden_states = gate * self.proj_out(hidden_states)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
return encoder_hidden_states, hidden_states
class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
"""
The Transformer model introduced in Flux. Based on FluxPipeline with several changes:
- no pooled embeddings
- We use zero padding for prompts
- No guidance embedding since this is not a distilled version
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
Parameters:
patch_size (`int`): Patch size to turn the input data into small patches.
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = None,
guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56],
rope_theta=10000,
time_theta=10000,
):
super().__init__()
self.out_channels = in_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.pos_embed = BriaEmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
self.time_embed = BriaTimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
if guidance_embeds:
self.guidance_embed = BriaTimestepProjEmbeddings(embedding_dim=self.inner_dim)
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
BriaTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
BriaSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_single_layers)
]
)
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`BriaTransformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype)
if guidance is not None:
guidance = guidance.to(hidden_states.dtype)
else:
guidance = None
temb = self.time_embed(timestep, dtype=hidden_states.dtype)
if guidance:
temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if len(txt_ids.shape) == 3:
txt_ids = txt_ids[0]
if len(img_ids.shape) == 3:
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
attention_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
attention_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
# controlnet residual
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -127,6 +127,7 @@ else:
"AnimateDiffVideoToVideoPipeline",
"AnimateDiffVideoToVideoControlNetPipeline",
]
_import_structure["bria"] = ["BriaPipeline"]
_import_structure["flux"] = [
"FluxControlPipeline",
"FluxControlInpaintPipeline",
@@ -552,6 +553,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .bria import BriaPipeline
from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
from .cogvideo import (
CogVideoXFunControlPipeline,

View File

@@ -0,0 +1,48 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_bria"] = ["BriaPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_bria import BriaPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,729 @@
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import (
CLIPImageProcessor,
CLIPVisionModelWithProjection,
T5EncoderModel,
T5TokenizerFast,
)
from ...image_processor import VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin
from ...models import AutoencoderKL
from ...models.transformers.transformer_bria import BriaTransformer2DModel
from ...pipelines import DiffusionPipeline
from ...pipelines.bria.pipeline_output import BriaPipelineOutput
from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from ...schedulers import (
DDIMScheduler,
EulerAncestralDiscreteScheduler,
FlowMatchEulerDiscreteScheduler,
KarrasDiffusionSchedulers,
)
from ...utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import BriaPipeline
>>> pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
# BRIA's T5 text encoder is sensitive to precision. We need to cast it to bfloat16 and keep the final layer in float32.
>>> pipe.text_encoder = pipe.text_encoder.to(dtype=torch.bfloat16)
>>> for block in pipe.text_encoder.encoder.block:
... block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)
# BRIA's VAE is not supported in mixed precision, so we use float32.
>>> if pipe.vae.config.shift_factor == 0:
... pipe.vae.to(dtype=torch.float32)
>>> prompt = "Photorealistic food photography of a stack of fluffy pancakes on a white plate, with maple syrup being poured over them. On top of the pancakes are the words 'BRIA 3.2' in bold, yellow, 3D letters. The background is dark and out of focus."
>>> image = pipe(prompt).images[0]
>>> image.save("bria.png")
```
"""
def is_ng_none(negative_prompt):
return (
negative_prompt is None
or negative_prompt == ""
or (isinstance(negative_prompt, list) and negative_prompt[0] is None)
or (type(negative_prompt) == list and negative_prompt[0] == "")
)
def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000):
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
sigmas = timesteps / num_train_timesteps
inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)]
new_sigmas = sigmas[inds]
return new_sigmas
class BriaPipeline(DiffusionPipeline):
r"""
Based on FluxPipeline with several changes:
- no pooled embeddings
- We use zero padding for prompts
- No guidance embedding since this is not a distilled version
Args:
transformer ([`BriaTransformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. Bria uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
[t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
_optional_components = ["image_encoder", "feature_extractor"]
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
transformer: BriaTransformer2DModel,
scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
vae: AutoencoderKL,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
image_encoder: CLIPVisionModelWithProjection = None,
feature_extractor: CLIPImageProcessor = None,
):
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
image_encoder=image_encoder,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = 64 # due to patchify=> 128,128 => res of 1k,1k
if self.vae.config.shift_factor is None:
self.vae.config.shift_factor = 0
self.vae.to(dtype=torch.float32)
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 128,
lora_scale: Optional[float] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
device = device or self._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
if not is_ng_none(negative_prompt):
negative_prompt = (
batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
)
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
else:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
if self.text_encoder is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device)
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
return prompt_embeds, negative_prompt_embeds, text_ids
@property
def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def attention_kwargs(self):
return self._attention_kwargs
@attention_kwargs.setter
def attention_kwargs(self, value):
self._attention_kwargs = value
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 128,
device: Optional[torch.device] = None,
):
tokenizer = self.tokenizer
text_encoder = self.text_encoder
device = device or text_encoder.device
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
prompt_embeds_list = []
for p in prompt:
text_inputs = tokenizer(
p,
# padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
# Concat zeros to max_sequence
b, seq_len, dim = prompt_embeds.shape
if seq_len < max_sequence_length:
padding = torch.zeros(
(b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
)
prompt_embeds = torch.concat([prompt_embeds, padding], dim=1)
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=0)
prompt_embeds = prompt_embeds.to(device=device)
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, max_sequence_length, -1)
prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
return prompt_embeds
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // self.vae_scale_factor)
width = 2 * (int(width) // self.vae_scale_factor)
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents, latent_image_ids
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor
width = width // vae_scale_factor
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
return latents
@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 30,
timesteps: List[int] = None,
guidance_scale: float = 5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 128,
clip_value: Union[None, float] = None,
normalize: bool = False,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.bria.BriaPipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.bria.BriaPipelineOutput`] or `tuple`: [`~pipelines.bria.BriaPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
height=height,
width=width,
prompt_embeds=prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self.attention_kwargs = attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
(prompt_embeds, negative_prompt_embeds, text_ids) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4 # due to patch=2, we devide by 4
latents, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
if (
isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler)
and self.scheduler.config["use_dynamic_shifting"]
):
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
mu=mu,
)
else:
# 4. Prepare timesteps
# Sample from training sigmas
if isinstance(self.scheduler, DDIMScheduler) or isinstance(
self.scheduler, EulerAncestralDiscreteScheduler
):
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, None, None
)
else:
sigmas = get_original_sigmas(
num_train_timesteps=self.scheduler.config.num_train_timesteps,
num_inference_steps=num_inference_steps,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
if len(latent_image_ids.shape) == 3:
latent_image_ids = latent_image_ids[0]
if len(text_ids.shape) == 3:
text_ids = text_ids[0]
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
if type(self.scheduler) != FlowMatchEulerDiscreteScheduler:
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# This is predicts "v" from flow-matching or eps from diffusion
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
attention_kwargs=self.attention_kwargs,
return_dict=False,
txt_ids=text_ids,
img_ids=latent_image_ids,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
cfg_noise_pred_text = noise_pred_text.std()
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if normalize:
noise_pred = noise_pred * (0.7 * (cfg_noise_pred_text / noise_pred.std())) + 0.3 * noise_pred
if clip_value:
assert clip_value > 0
noise_pred = noise_pred.clip(-clip_value, clip_value)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents.to(dtype=torch.float32) / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return BriaPipelineOutput(images=image)

View File

@@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from ...utils import BaseOutput
@dataclass
class BriaPipelineOutput(BaseOutput):
"""
Output class for Bria pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]

View File

@@ -528,6 +528,21 @@ class AutoModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class BriaTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class CacheMixin(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -362,6 +362,21 @@ class AuraFlowPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class BriaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class ChromaImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -0,0 +1,181 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import BriaTransformer2DModel
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
from diffusers.models.embeddings import ImageProjection
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
def create_bria_ip_adapter_state_dict(model):
# "ip_adapter" (cross-attention weights)
ip_cross_attn_state_dict = {}
key_id = 0
for name in model.attn_processors.keys():
if name.startswith("single_transformer_blocks"):
continue
joint_attention_dim = model.config["joint_attention_dim"]
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
sd = FluxIPAdapterJointAttnProcessor2_0(
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
).state_dict()
ip_cross_attn_state_dict.update(
{
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
}
)
key_id += 1
# "image_proj" (ImageProjection layer weights)
image_projection = ImageProjection(
cross_attention_dim=model.config["joint_attention_dim"],
image_embed_dim=model.config["pooled_projection_dim"],
num_image_text_embeds=4,
)
ip_image_projection_state_dict = {}
sd = image_projection.state_dict()
ip_image_projection_state_dict.update(
{
"proj.weight": sd["image_embeds.weight"],
"proj.bias": sd["image_embeds.bias"],
"norm.weight": sd["norm.weight"],
"norm.bias": sd["norm.bias"],
}
)
del sd
ip_state_dict = {}
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
return ip_state_dict
class BriaTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = BriaTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.8, 0.7, 0.7]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
height = width = 4
sequence_length = 48
embedding_dim = 32
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
}
@property
def input_shape(self):
return (16, 4)
@property
def output_shape(self):
return (16, 4)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
"num_single_layers": 1,
"attention_head_dim": 8,
"num_attention_heads": 2,
"joint_attention_dim": 32,
"pooled_projection_dim": None,
"axes_dims_rope": [0, 4, 4],
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_deprecated_inputs_img_txt_ids_3d(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output_1 = model(**inputs_dict).to_tuple()[0]
# update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"
inputs_dict["txt_ids"] = text_ids_3d
inputs_dict["img_ids"] = image_ids_3d
with torch.no_grad():
output_2 = model(**inputs_dict).to_tuple()[0]
self.assertEqual(output_1.shape, output_2.shape)
self.assertTrue(
torch.allclose(output_1, output_2, atol=1e-5),
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
)
def test_gradient_checkpointing_is_applied(self):
expected_set = {"BriaTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class BriaTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = BriaTransformer2DModel
def prepare_init_args_and_inputs_for_common(self):
return BriaTransformerTests().prepare_init_args_and_inputs_for_common()
class BriaTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = BriaTransformer2DModel
def prepare_init_args_and_inputs_for_common(self):
return BriaTransformerTests().prepare_init_args_and_inputs_for_common()

View File

View File

@@ -0,0 +1,318 @@
# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from transformers import T5EncoderModel, T5TokenizerFast
from diffusers import (
AutoencoderKL,
BriaTransformer2DModel,
FlowMatchEulerDiscreteScheduler,
)
from diffusers.pipelines.bria import BriaPipeline
from diffusers.utils.testing_utils import (
enable_full_determinism,
numpy_cosine_similarity_distance,
require_accelerator,
require_torch_gpu,
slow,
torch_device,
)
# from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
from tests.pipelines.test_pipelines_common import PipelineTesterMixin, to_np
enable_full_determinism()
class BriaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = BriaPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
batch_params = frozenset(["prompt"])
test_xformers_attention = False
# there is no xformers processor for Flux
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
transformer = BriaTransformer2DModel(
patch_size=1,
in_channels=16,
num_layers=1,
num_single_layers=1,
attention_head_dim=8,
num_attention_heads=2,
joint_attention_dim=32,
pooled_projection_dim=None,
axes_dims_rope=[0, 4, 4],
)
torch.manual_seed(0)
vae = AutoencoderKL(
act_fn="silu",
block_out_channels=(32,),
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D"],
latent_channels=4,
sample_size=32,
shift_factor=0,
scaling_factor=0.13025,
use_post_quant_conv=True,
use_quant_conv=True,
force_upcast=False,
)
scheduler = FlowMatchEulerDiscreteScheduler()
torch.manual_seed(0)
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = T5TokenizerFast.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"transformer": transformer,
"vae": vae,
"image_encoder": None,
"feature_extractor": None,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"negative_prompt": "bad, ugly",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"height": 16,
"width": 16,
"max_sequence_length": 48,
"output_type": "np",
}
return inputs
def test_encode_prompt_works_in_isolation(self):
pass
def test_bria_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt"] = "a different prompt"
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
assert max_diff > 1e-6
def test_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (72, 57)]
for height, width in height_width_pairs:
expected_height = height - height % (pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.vae_scale_factor * 2)
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
@require_accelerator
def test_save_load_float16(self, expected_max_diff=1e-2):
components = self.get_dummy_components()
for name, module in components.items():
if hasattr(module, "half"):
components[name] = module.to(torch_device).half()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0]
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16)
for component in pipe_loaded.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
for name, component in pipe_loaded.components.items():
if name == "vae":
continue
if hasattr(component, "dtype"):
self.assertTrue(
component.dtype == torch.float16,
f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
)
inputs = self.get_dummy_inputs(torch_device)
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(
max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
)
def test_bria_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(16, 16), (32, 32), (64, 64)]
for height, width in height_width_pairs:
expected_height = height - height % (pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.vae_scale_factor * 2)
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)
def test_to_dtype(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
self.assertTrue([dtype == torch.float32 for dtype in model_dtypes] == [True, True, True])
def test_torch_dtype_dict(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname)
torch_dtype_dict = {"transformer": torch.bfloat16, "default": torch.float16}
loaded_pipe = self.pipeline_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype_dict)
self.assertEqual(loaded_pipe.transformer.dtype, torch.bfloat16)
self.assertEqual(loaded_pipe.text_encoder.dtype, torch.float16)
self.assertEqual(loaded_pipe.vae.dtype, torch.float16)
with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname)
torch_dtype_dict = {"default": torch.float16}
loaded_pipe = self.pipeline_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype_dict)
self.assertEqual(loaded_pipe.transformer.dtype, torch.float16)
self.assertEqual(loaded_pipe.text_encoder.dtype, torch.float16)
self.assertEqual(loaded_pipe.vae.dtype, torch.float16)
@slow
@require_torch_gpu
class BriaPipelineSlowTests(unittest.TestCase):
pipeline_class = BriaPipeline
repo_id = "briaai/BRIA-3.2"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, seed=0):
generator = torch.Generator(device="cpu").manual_seed(seed)
prompt_embeds = torch.load(
hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt")
).to(torch_device)
return {
"prompt_embeds": prompt_embeds,
"num_inference_steps": 2,
"guidance_scale": 0.0,
"max_sequence_length": 256,
"output_type": "np",
"generator": generator,
}
def test_bria_inference_bf16(self):
pipe = self.pipeline_class.from_pretrained(
self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, tokenizer=None
)
pipe.to(torch_device)
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10].flatten()
expected_slice = np.array(
[
0.59729785,
0.6153719,
0.595112,
0.5884763,
0.59366125,
0.5795311,
0.58325,
0.58449626,
0.57737637,
0.58432233,
0.5867875,
0.57824117,
0.5819089,
0.5830988,
0.57730293,
0.57647324,
0.5769151,
0.57312685,
0.57926565,
0.5823928,
0.57783926,
0.57162863,
0.575649,
0.5745547,
0.5740556,
0.5799735,
0.57799566,
0.5715559,
0.5771242,
0.5773058,
],
dtype=np.float32,
)
max_diff = numpy_cosine_similarity_distance(expected_slice, image_slice)
self.assertLess(max_diff, 1e-4, f"Image slice is different from expected slice: {max_diff:.4f}")