mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Bria FIBO pipeline

This commit is contained in:
Gal Davidi
2025-10-26 16:41:39 +00:00
parent d34b18c783
commit 9e253a7bb7
15 changed files with 2948 additions and 0 deletions


@@ -0,0 +1,37 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Bria Fibo
Text-to-image models have mastered imagination, but not control. FIBO changes that.
FIBO is trained on structured JSON captions of up to 1,000+ words and is designed to understand and control visual parameters such as lighting, composition, color, and camera settings, enabling precise and reproducible outputs.
With only 8 billion parameters, FIBO delivers a new level of image quality, prompt adherence, and professional control.
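For illustration, a structured caption is a JSON object whose fields spell out the scene and its visual parameters. The sketch below borrows field names from the captioning schema defined later in this commit (`short_description`, `objects`, `lighting`, `aesthetics`, `photographic_characteristics`, `style_medium`); it is a minimal, illustrative example, not the full schema.
```python
# Minimal illustrative structured caption (hedged sketch, not the complete schema).
structured_caption = {
    "short_description": "A glossy white ceramic teapot on a wooden table by a window",
    "objects": [
        {
            "description": "A glossy white ceramic teapot",
            "location": "center",
            "relative_size": "large within frame",
            "relationship": "sits on the wooden table",
        }
    ],
    "background_setting": "A rustic kitchen with soft window light",
    "lighting": {"conditions": "golden hour", "direction": "side-lit from left"},
    "aesthetics": {
        "composition": "rule of thirds",
        "color_scheme": "warm complementary colors",
        "mood_atmosphere": "serene",
    },
    "photographic_characteristics": {
        "depth_of_field": "shallow",
        "focus": "sharp focus on subject",
        "camera_angle": "eye-level",
        "lens_focal_length": "standard lens (e.g., 35mm-50mm)",
    },
    "style_medium": "photograph",
}
# When used as a prompt, the caption is serialized to a compact JSON string, e.g. json.dumps(structured_caption).
```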
## Usage
_As the model is gated, you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO), fill in the form, and accept the gate before using it with diffusers. Once you have access, log in so that your system knows you've accepted the gate._
Use the command below to log in:
```bash
hf auth login
```
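Once logged in, the pipeline can be loaded like other diffusers pipelines. The snippet below is a minimal sketch: `from_pretrained` and `torch_dtype` are standard diffusers API and the checkpoint id comes from the gated page above, but the exact call arguments accepted by `BriaFiboPipeline` (prompt format, defaults) should be verified against the model card.
```python
import json

import torch
from diffusers import BriaFiboPipeline

# Hedged sketch: argument names below follow common diffusers conventions and may differ for FIBO.
pipe = BriaFiboPipeline.from_pretrained("briaai/FIBO", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# A structured JSON caption (such as the illustrative sketch above) serialized to a compact string.
structured_prompt = json.dumps({"short_description": "A glossy white ceramic teapot on a wooden table by a window"})

image = pipe(prompt=structured_prompt, num_inference_steps=30).images[0]
image.save("fibo_sample.png")
```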
## BriaFiboPipeline
[[autodoc]] BriaFiboPipeline
- all
- __call__


@@ -199,6 +199,7 @@ else:
"AutoencoderTiny",
"AutoModel",
"BriaTransformer2DModel",
"BriaFiboTransformer2DModel",
"CacheMixin",
"ChromaTransformer2DModel",
"CogVideoXTransformer3DModel",
@@ -392,6 +393,8 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["modular_pipelines"].extend(
[
"BriaFiboVLMPromptToJson",
"BriaFiboGeminiPromptToJson",
"FluxAutoBlocks",
"FluxKontextAutoBlocks",
"FluxKontextModularPipeline",
@@ -431,6 +434,7 @@ else:
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"BriaPipeline",
"BriaFiboPipeline",
"ChromaImg2ImgPipeline",
"ChromaPipeline",
"CLIPImageProjection",
@@ -902,6 +906,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderTiny,
AutoModel,
BriaTransformer2DModel,
BriaFiboTransformer2DModel,
CacheMixin,
ChromaTransformer2DModel,
CogVideoXTransformer3DModel,
@@ -1104,6 +1109,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDMPipeline,
AuraFlowPipeline,
BriaPipeline,
BriaFiboPipeline,
ChromaImg2ImgPipeline,
ChromaPipeline,
CLIPImageProjection,


@@ -84,6 +84,7 @@ if is_torch_available():
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
_import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"]
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
@@ -175,6 +176,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AllegroTransformer3DModel,
AuraFlowTransformer2DModel,
BriaTransformer2DModel,
BriaFiboTransformer2DModel,
ChromaTransformer2DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,


@@ -0,0 +1,446 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention_processor import Attention
from ...models.embeddings import TimestepEmbedding, get_timestep_embedding
from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZeroSingle
from ...models.transformers.transformer_bria import BriaAttnProcessor
from ...models.transformers.transformer_flux import FluxTransformerBlock
from ...utils import (
USE_PEFT_BACKEND,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import maybe_allow_in_graph
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[np.ndarray, int],
theta: float = 10000.0,
use_real=False,
linear_factor=1.0,
ntk_factor=1.0,
repeat_interleave_real=True,
freqs_dtype=torch.float32, # torch.float32, torch.float64
):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions. This function calculates a
frequency tensor with complex exponentials using the given dimension 'dim' and the end index 'end'. The 'theta'
parameter scales the frequencies. The returned tensor contains complex values in complex64 data type.
Args:
dim (`int`): Dimension of the frequency tensor.
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
theta (`float`, *optional*, defaults to 10000.0):
Scaling factor for frequency computation. Defaults to 10000.0.
use_real (`bool`, *optional*):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
linear_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the context extrapolation. Defaults to 1.0.
ntk_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
Otherwise, they are concatenated with themselves.
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
the dtype of the frequency tensor.
Returns:
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
assert dim % 2 == 0
if isinstance(pos, int):
pos = torch.arange(pos)
if isinstance(pos, np.ndarray):
pos = torch.from_numpy(pos) # type: ignore # [S]
theta = theta * ntk_factor
freqs = (
1.0
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
/ linear_factor
) # [D/2]
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
if use_real and repeat_interleave_real:
# flux, hunyuan-dit, cogvideox
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
return freqs_cos, freqs_sin
elif use_real:
# stable audio, allegro
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
return freqs_cos, freqs_sin
else:
# lumina
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
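# Example: with dim=56 and pos=np.arange(4), calling this with use_real=True and
# repeat_interleave_real=True yields a (cos, sin) pair of shape [4, 56]; the [4, 28]
# base frequency table is interleaved with itself along the last dimension to reach dim.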
class EmbedND(torch.nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
@maybe_allow_in_graph
class BriaFiboSingleTransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)
self.norm = AdaLayerNormZeroSingle(dim)
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
processor = BriaAttnProcessor()
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
qk_norm="rms_norm",
eps=1e-6,
pre_only=True,
)
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
joint_attention_kwargs = joint_attention_kwargs or {}
attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
gate = gate.unsqueeze(1)
hidden_states = gate * self.proj_out(hidden_states)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return hidden_states
class TextProjection(nn.Module):
def __init__(self, in_features, hidden_size):
super().__init__()
self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)
def forward(self, caption):
hidden_states = self.linear(caption)
return hidden_states
class Timesteps(nn.Module):
def __init__(
self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
self.time_theta = time_theta
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
max_period=self.time_theta,
)
return t_emb
class TimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, time_theta):
super().__init__()
self.time_proj = Timesteps(
num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timestep, dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D)
return timesteps_emb
class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
"""
Parameters:
patch_size (`int`): Patch size to turn the input data into small patches.
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
...
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = None,
guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56],
rope_theta=10000,
time_theta=10000,
text_encoder_dim: int = 2048,
):
super().__init__()
self.out_channels = in_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
if guidance_embeds:
self.guidance_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
BriaFiboSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_single_layers)
]
)
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
caption_projection = [
TextProjection(in_features=text_encoder_dim, hidden_size=self.inner_dim // 2)
for i in range(self.config.num_layers + self.config.num_single_layers)
]
self.caption_projection = nn.ModuleList(caption_projection)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
text_encoder_layers: list = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype)
if guidance is not None:
guidance = guidance.to(hidden_states.dtype)
temb = self.time_embed(timestep, dtype=hidden_states.dtype)
if guidance is not None:
temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if len(txt_ids.shape) == 3:
txt_ids = txt_ids[0]
if len(img_ids.shape) == 3:
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
new_text_encoder_layers = []
for i, text_encoder_layer in enumerate(text_encoder_layers):
text_encoder_layer = self.caption_projection[i](text_encoder_layer)
new_text_encoder_layers.append(text_encoder_layer)
text_encoder_layers = new_text_encoder_layers
block_id = 0
for index_block, block in enumerate(self.transformer_blocks):
current_text_encoder_layer = text_encoder_layers[block_id]
encoder_hidden_states = torch.cat(
[encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1
)
block_id += 1
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
joint_attention_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
for index_block, block in enumerate(self.single_transformer_blocks):
current_text_encoder_layer = text_encoder_layers[block_id]
encoder_hidden_states = torch.cat(
[encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1
)
block_id += 1
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
temb,
image_rotary_emb,
joint_attention_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
encoder_hidden_states = hidden_states[:, : encoder_hidden_states.shape[1], ...]
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)


@@ -45,6 +45,7 @@ else:
"InsertableDict",
]
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
_import_structure["bria_fibo"] = ["BriaFiboVLMPromptToJson", "BriaFiboGeminiPromptToJson"]
_import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
_import_structure["flux"] = [
"FluxAutoBlocks",
@@ -69,6 +70,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .bria_fibo import BriaFiboGeminiPromptToJson, BriaFiboVLMPromptToJson
from .components_manager import ComponentsManager
from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
from .modular_pipeline import (


@@ -0,0 +1,47 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["gemini_prompt_to_json"] = ["BriaFiboGeminiPromptToJson"]
_import_structure["fibo_vlm_prompt_to_json"] = ["BriaFiboVLMPromptToJson"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .gemini_prompt_to_json import BriaFiboGeminiPromptToJson
from .fibo_vlm_prompt_to_json import BriaFiboVLMPromptToJson
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)


@@ -0,0 +1,377 @@
import json
import math
import textwrap
from typing import Any, Dict, Iterable, List, Optional
import torch
from boltons.iterutils import remap
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, Qwen3VLForConditionalGeneration
from .. import ComponentSpec, InputParam, ModularPipelineBlocks, OutputParam, PipelineState
def parse_aesthetic_score(record: dict) -> str:
ae = record["aesthetic_score"]
if ae < 5.5:
return "very low"
elif ae < 6:
return "low"
elif ae < 7:
return "medium"
elif ae < 7.6:
return "high"
else:
return "very high"
def parse_pickascore(record: dict) -> str:
ps = record["pickascore"]
if ps < 0.78:
return "very low"
elif ps < 0.82:
return "low"
elif ps < 0.87:
return "medium"
elif ps < 0.91:
return "high"
else:
return "very high"
def prepare_clean_caption(record: dict) -> str:
def keep(p, k, v):
is_none = v is None
is_empty_string = isinstance(v, str) and v == ""
is_empty_dict = isinstance(v, dict) and not v
is_empty_list = isinstance(v, list) and not v
is_nan = isinstance(v, float) and math.isnan(v)
if is_none or is_empty_string or is_empty_list or is_empty_dict or is_nan:
return False
return True
try:
scores = {}
if "pickascore" in record:
scores["preference_score"] = parse_pickascore(record)
if "aesthetic_score" in record:
scores["aesthetic_score"] = parse_aesthetic_score(record)
clean_caption_dict = remap(record, visit=keep)
# Set aesthetics scores
if "aesthetics" not in clean_caption_dict:
if len(scores) > 0:
clean_caption_dict["aesthetics"] = scores
else:
clean_caption_dict["aesthetics"].update(scores)
# Dump the clean structured caption as a minimal JSON string (i.e., no newline/whitespace separators)
clean_caption_str = json.dumps(clean_caption_dict)
return clean_caption_str
except Exception as ex:
print("Error: ", ex)
raise ex
def _collect_images(messages: Iterable[Dict[str, Any]]) -> List[Image.Image]:
images: List[Image.Image] = []
for message in messages:
content = message.get("content", [])
if not isinstance(content, list):
continue
for item in content:
if not isinstance(item, dict):
continue
if item.get("type") != "image":
continue
image_value = item.get("image")
if isinstance(image_value, Image.Image):
images.append(image_value)
else:
raise ValueError("Expected PIL.Image for image content in messages.")
return images
def _strip_stop_sequences(text: str, stop_sequences: Optional[List[str]]) -> str:
if not stop_sequences:
return text.strip()
cleaned = text
for stop in stop_sequences:
if not stop:
continue
index = cleaned.find(stop)
if index >= 0:
cleaned = cleaned[:index]
return cleaned.strip()
class TransformersEngine(torch.nn.Module):
"""Inference wrapper using Hugging Face transformers."""
def __init__(
self,
model: str,
*,
processor_kwargs: Optional[Dict[str, Any]] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
super(TransformersEngine, self).__init__()
default_processor_kwargs: Dict[str, Any] = {
"min_pixels": 256 * 28 * 28,
"max_pixels": 1024 * 28 * 28,
}
processor_kwargs = {**default_processor_kwargs, **(processor_kwargs or {})}
model_kwargs = model_kwargs or {}
self.processor = AutoProcessor.from_pretrained(model, **processor_kwargs)
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
model,
dtype=torch.bfloat16,
**model_kwargs,
)
self.model.eval()
tokenizer_obj = self.processor.tokenizer
if tokenizer_obj.pad_token_id is None:
tokenizer_obj.pad_token = tokenizer_obj.eos_token
self._pad_token_id = tokenizer_obj.pad_token_id
eos_token_id = tokenizer_obj.eos_token_id
if isinstance(eos_token_id, list) and eos_token_id:
self._eos_token_id = eos_token_id
elif eos_token_id is not None:
self._eos_token_id = [eos_token_id]
else:
raise ValueError("Tokenizer must define an EOS token for generation.")
def dtype(self) -> torch.dtype:
return self.model.dtype
def device(self) -> torch.device:
return self.model.device
def _to_model_device(self, value: Any) -> Any:
if not isinstance(value, torch.Tensor):
return value
target_device = getattr(self.model, "device", None)
if target_device is None or target_device.type == "meta":
return value
if value.device == target_device:
return value
return value.to(target_device)
def generate(
self,
messages: List[Dict[str, Any]],
top_p: float,
temperature: float,
max_tokens: int,
stop: Optional[List[str]] = None,
) -> Dict[str, Any]:
tokenizer = self.processor.tokenizer
prompt_text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
processor_inputs: Dict[str, Any] = {
"text": [prompt_text],
"padding": True,
"return_tensors": "pt",
}
images = _collect_images(messages)
if images:
processor_inputs["images"] = images
inputs = self.processor(**processor_inputs)
inputs = {key: self._to_model_device(value) for key, value in inputs.items()}
generation_kwargs: Dict[str, Any] = {
"max_new_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
"do_sample": temperature > 0,
"eos_token_id": self._eos_token_id,
"pad_token_id": self._pad_token_id,
}
with torch.inference_mode():
generated_ids = self.model.generate(**inputs, **generation_kwargs)
input_ids = inputs.get("input_ids")
if input_ids is None:
raise RuntimeError("Processor did not return input_ids; cannot compute new tokens.")
new_token_ids = generated_ids[:, input_ids.shape[-1] :]
decoded = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)
if not decoded:
return ""
text = decoded[0]
stripped_text = _strip_stop_sequences(text, stop)
json_prompt = json.loads(stripped_text)
return json_prompt
def generate_json_prompt(
vlm_processor: "TransformersEngine",
top_p: float,
temperature: float,
max_tokens: int,
stop: List[str],
image: Optional[Image.Image] = None,
prompt: Optional[str] = None,
structured_prompt: Optional[str] = None,
):
if image is None and structured_prompt is None:
# only got prompt
task = "generate"
editing_instructions = None
elif image is None and structured_prompt is not None and prompt is not None:
# got structured prompt and prompt
task = "refine"
editing_instructions = prompt
elif image is not None and structured_prompt is None and prompt is not None:
# got image and prompt
task = "refine"
editing_instructions = prompt
elif image is not None and structured_prompt is None and prompt is None:
# only got image
task = "inspire"
editing_instructions = None
else:
raise ValueError("Invalid input")
messages = build_messages(
task,
image=image,
prompt=prompt,
structured_prompt=structured_prompt,
editing_instructions=editing_instructions,
)
generated_prompt = vlm_processor.generate(
messages=messages, top_p=top_p, temperature=temperature, max_tokens=max_tokens, stop=stop
)
cleaned_json_data = prepare_clean_caption(generated_prompt)
return cleaned_json_data
def build_messages(
task: str,
*,
image: Optional[Image.Image] = None,
refine_image: Optional[Image.Image] = None,
prompt: Optional[str] = None,
structured_prompt: Optional[str] = None,
editing_instructions: Optional[str] = None,
) -> List[Dict[str, Any]]:
user_content: List[Dict[str, Any]] = []
if task == "inspire":
user_content.append({"type": "image", "image": image})
user_content.append({"type": "text", "text": "<inspire>"})
elif task == "generate":
text_value = (prompt or "").strip()
formatted = f"<generate>\n{text_value}"
user_content.append({"type": "text", "text": formatted})
else: # refine
if refine_image is None:
base_prompt = (structured_prompt or "").strip()
edits = (editing_instructions or "").strip()
formatted = textwrap.dedent(
f"""<refine> Input: {base_prompt} Editing instructions: {edits}"""
).strip()
user_content.append({"type": "text", "text": formatted})
else:
user_content.append({"type": "image", "image": refine_image})
edits = (editing_instructions or "").strip()
formatted = textwrap.dedent(
f"""<refine> Editing instructions: {edits}"""
).strip()
user_content.append({"type": "text", "text": formatted})
messages: List[Dict[str, Any]] = []
messages.append({"role": "user", "content": user_content})
return messages
class BriaFiboVLMPromptToJson(ModularPipelineBlocks):
model_name = "BriaFibo"
def __init__(self, model_id):
super().__init__()
self.engine = TransformersEngine(model_id)
self.engine.model.to("cuda")
@property
def expected_components(self) -> List[ComponentSpec]:
return []
@property
def inputs(self) -> List[InputParam]:
prompt_input = InputParam(
"prompt",
type_hint=str,
required=False,
description="Prompt to use",
)
image_input = InputParam(
name="image", type_hint=Image.Image, required=False, description="image for inspiration mode"
)
json_prompt_input = InputParam(
name="json_prompt", type_hint=str, required=False, description="JSON prompt to use"
)
sampling_top_p_input = InputParam(
name="sampling_top_p", type_hint=float, required=False, description="Sampling top p", default=0.9
)
sampling_temperature_input = InputParam(
name="sampling_temperature",
type_hint=float,
required=False,
description="Sampling temperature",
default=0.2,
)
sampling_max_tokens_input = InputParam(
name="sampling_max_tokens", type_hint=int, required=False, description="Sampling max tokens", default=4096
)
return [
prompt_input,
image_input,
json_prompt_input,
sampling_top_p_input,
sampling_temperature_input,
sampling_max_tokens_input,
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return []
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"json_prompt",
type_hint=str,
description="JSON prompt by the VLM",
)
]
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
prompt = block_state.prompt
image = block_state.image
json_prompt = block_state.json_prompt
block_state.json_prompt = generate_json_prompt(
vlm_processor=self.engine,
image=image,
prompt=prompt,
structured_prompt=json_prompt,
top_p=block_state.sampling_top_p,
temperature=block_state.sampling_temperature,
max_tokens=block_state.sampling_max_tokens,
stop=["<|im_end|>", "<|end_of_text|>"],
)
self.set_block_state(state, block_state)
return components, state


@@ -0,0 +1,804 @@
import io
import json
import math
import os
from functools import cache
from typing import List, Optional, Tuple
from boltons.iterutils import remap
from google import genai
from PIL import Image
from pydantic import BaseModel, Field
from ...modular_pipelines import InputParam, ModularPipelineBlocks, OutputParam, PipelineState
class ObjectDescription(BaseModel):
description: str = Field(..., description="Short description of the object.")
location: str = Field(..., description="E.g., 'center', 'top-left', 'bottom-right foreground'.")
relationship: str = Field(
..., description="Describe the relationship between the object and the other objects in the image."
)
relative_size: Optional[str] = Field(None, description="E.g., 'small', 'medium', 'large within frame'.")
shape_and_color: Optional[str] = Field(None, description="Describe the basic shape and dominant color.")
texture: Optional[str] = Field(None, description="E.g., 'smooth', 'rough', 'metallic', 'furry'.")
appearance_details: Optional[str] = Field(None, description="Any other notable visual details.")
# If cluster of object
number_of_objects: Optional[int] = Field(None, description="The number of objects in the cluster.")
# Human-specific fields
pose: Optional[str] = Field(None, description="Describe the body position.")
expression: Optional[str] = Field(None, description="Describe facial expression.")
clothing: Optional[str] = Field(None, description="Describe attire.")
action: Optional[str] = Field(None, description="Describe the action of the human.")
gender: Optional[str] = Field(None, description="Describe the gender of the human.")
skin_tone_and_texture: Optional[str] = Field(None, description="Describe the skin tone and texture.")
orientation: Optional[str] = Field(None, description="Describe the orientation of the human.")
class LightingDetails(BaseModel):
conditions: str = Field(
..., description="E.g., 'bright daylight', 'dim indoor', 'studio lighting', 'golden hour'."
)
direction: str = Field(..., description="E.g., 'front-lit', 'backlit', 'side-lit from left'.")
shadows: Optional[str] = Field(None, description="Describe the presence of shadows.")
class AestheticsDetails(BaseModel):
composition: str = Field(..., description="E.g., 'rule of thirds', 'symmetrical', 'centered', 'leading lines'.")
color_scheme: str = Field(
..., description="E.g., 'monochromatic blue', 'warm complementary colors', 'high contrast'."
)
mood_atmosphere: str = Field(..., description="E.g., 'serene', 'energetic', 'mysterious', 'joyful'.")
class PhotographicCharacteristicsDetails(BaseModel):
depth_of_field: str = Field(..., description="E.g., 'shallow', 'deep', 'bokeh background'.")
focus: str = Field(..., description="E.g., 'sharp focus on subject', 'soft focus', 'motion blur'.")
camera_angle: str = Field(..., description="E.g., 'eye-level', 'low angle', 'high angle', 'dutch angle'.")
lens_focal_length: str = Field(..., description="E.g., 'wide-angle', 'telephoto', 'macro', 'fisheye'.")
class TextRender(BaseModel):
text: str = Field(..., description="The text content.")
location: str = Field(..., description="E.g., 'center', 'top-left', 'bottom-right foreground'.")
size: str = Field(..., description="E.g., 'small', 'medium', 'large within frame'.")
color: str = Field(..., description="E.g., 'red', 'blue', 'green'.")
font: str = Field(..., description="E.g., 'realistic', 'cartoonish', 'minimalist'.")
appearance_details: Optional[str] = Field(None, description="Any other notable visual details.")
class ImageAnalysis(BaseModel):
short_description: str = Field(..., description="A concise summary of the image content, 200 words maximum.")
objects: List[ObjectDescription] = Field(..., description="List of prominent foreground/midground objects.")
background_setting: str = Field(
...,
description="Describe the overall environment, setting, or background, including any notable background elements.",
)
lighting: LightingDetails = Field(..., description="Details about the lighting.")
aesthetics: AestheticsDetails = Field(..., description="Details about the image aesthetics.")
photographic_characteristics: Optional[PhotographicCharacteristicsDetails] = Field(
None, description="Details about photographic characteristics."
)
style_medium: Optional[str] = Field(None, description="Identify the artistic style or medium.")
text_render: Optional[List[TextRender]] = Field(None, description="List of text renders in the image.")
context: str = Field(..., description="Provide any additional context that helps understand the image better.")
artistic_style: Optional[str] = Field(
None, description="describe specific artistic characteristics, 3 words maximum."
)
def get_gemini_output_schema() -> dict:
return {
"properties": {
"short_description": {"type": "STRING"},
"objects": {
"items": {
"properties": {
"description": {"type": "STRING"},
"location": {"type": "STRING"},
"relationship": {"type": "STRING"},
"relative_size": {"type": "STRING"},
"shape_and_color": {"type": "STRING"},
"texture": {"nullable": True, "type": "STRING"},
"appearance_details": {"nullable": True, "type": "STRING"},
"number_of_objects": {"nullable": True, "type": "INTEGER"},
"pose": {"nullable": True, "type": "STRING"},
"expression": {"nullable": True, "type": "STRING"},
"clothing": {"nullable": True, "type": "STRING"},
"action": {"nullable": True, "type": "STRING"},
"gender": {"nullable": True, "type": "STRING"},
"skin_tone_and_texture": {"nullable": True, "type": "STRING"},
"orientation": {"nullable": True, "type": "STRING"},
},
"required": [
"description",
"location",
"relationship",
"relative_size",
"shape_and_color",
"texture",
"appearance_details",
"number_of_objects",
"pose",
"expression",
"clothing",
"action",
"gender",
"skin_tone_and_texture",
"orientation",
],
"type": "OBJECT",
},
"type": "ARRAY",
},
"background_setting": {"type": "STRING"},
"lighting": {
"properties": {
"conditions": {"type": "STRING"},
"direction": {"type": "STRING"},
"shadows": {"nullable": True, "type": "STRING"},
},
"required": ["conditions", "direction", "shadows"],
"type": "OBJECT",
},
"aesthetics": {
"properties": {
"composition": {"type": "STRING"},
"color_scheme": {"type": "STRING"},
"mood_atmosphere": {"type": "STRING"},
},
"required": ["composition", "color_scheme", "mood_atmosphere"],
"type": "OBJECT",
},
"photographic_characteristics": {
"nullable": True,
"properties": {
"depth_of_field": {"type": "STRING"},
"focus": {"type": "STRING"},
"camera_angle": {"type": "STRING"},
"lens_focal_length": {"type": "STRING"},
},
"required": [
"depth_of_field",
"focus",
"camera_angle",
"lens_focal_length",
],
"type": "OBJECT",
},
"style_medium": {"type": "STRING"},
"text_render": {
"items": {
"properties": {
"text": {"type": "STRING"},
"location": {"type": "STRING"},
"size": {"type": "STRING"},
"color": {"type": "STRING"},
"font": {"type": "STRING"},
"appearance_details": {"nullable": True, "type": "STRING"},
},
"required": [
"text",
"location",
"size",
"color",
"font",
"appearance_details",
],
"type": "OBJECT",
},
"type": "ARRAY",
},
"context": {"type": "STRING"},
"artistic_style": {"type": "STRING"},
},
"required": [
"short_description",
"objects",
"background_setting",
"lighting",
"aesthetics",
"photographic_characteristics",
"style_medium",
"text_render",
"context",
"artistic_style",
],
"type": "OBJECT",
}
json_schema_full = """1. `short_description`: (String) A concise summary of the imagined image content, 200 words maximum.
2. `objects`: (Array of Objects) List a maximum of 5 prominent objects. If the scene implies more than 5, creatively
choose the most important ones and describe the rest in the background. For each object, include:
* `description`: (String) A detailed description of the imagined object, 100 words maximum.
* `location`: (String) E.g., "center", "top-left", "bottom-right foreground".
* `relative_size`: (String) E.g., "small", "medium", "large within frame". (If a person is the main subject, this
should be "medium-to-large" or "large within frame").
* `shape_and_color`: (String) Describe the basic shape and dominant color.
* `texture`: (String) E.g., "smooth", "rough", "metallic", "furry".
* `appearance_details`: (String) Any other notable visual details.
* `relationship`: (String) Describe the relationship between the object and the other objects in the image.
* `orientation`: (String) Describe the orientation or positioning of the object, e.g., "upright", "tilted 45
degrees", "horizontal", "vertical", "facing left", "facing right", "upside down", "lying on its side".
* If the object is a human or a human-like object, include the following:
* `pose`: (String) Describe the body position.
* `expression`: (String) Describe facial expression and emotion. E.g., "winking", "joyful", "serious",
"surprised", "calm".
* `clothing`: (String) Describe attire.
* `action`: (String) Describe the action of the human.
* `gender`: (String) Describe the gender of the human.
* `skin_tone_and_texture`: (String) Describe the skin tone and texture.
* If the object is a cluster of objects, include the following:
* `number_of_objects`: (Integer) The number of objects in the cluster.
3. `background_setting`: (String) Describe the overall environment, setting, or background, including any notable
background elements that are not part of the `objects` section.
4. `lighting`: (Object)
* `conditions`: (String) E.g., "bright daylight", "dim indoor", "studio lighting", "golden hour".
* `direction`: (String) E.g., "front-lit", "backlit", "side-lit from left".
* `shadows`: (String) Describe the presence and quality of shadows, e.g., "long, soft shadows", "sharp, defined
shadows", "minimal shadows".
5. `aesthetics`: (Object)
* `composition`: (String) E.g., "rule of thirds", "symmetrical", "centered", "leading lines". If people are the
main subject, specify the shot type, e.g., "medium shot", "close-up", "portrait composition".
* `color_scheme`: (String) E.g., "monochromatic blue", "warm complementary colors", "high contrast".
* `mood_atmosphere`: (String) E.g., "serene", "energetic", "mysterious", "joyful".
6. `photographic_characteristics`: (Object)
* `depth_of_field`: (String) E.g., "shallow", "deep", "bokeh background".
* `focus`: (String) E.g., "sharp focus on subject", "soft focus", "motion blur".
* `camera_angle`: (String) E.g., "eye-level", "low angle", "high angle", "dutch angle".
* `lens_focal_length`: (String) E.g., "wide-angle", "telephoto", "macro", "fisheye". (If the main subject is a
person, prefer "standard lens (e.g., 35mm-50mm)" or "portrait lens (e.g., 50mm-85mm)" to ensure they are framed
more closely. Avoid "wide-angle" for people unless specified).
7. `style_medium`: (String) Identify the artistic style or medium based on the user's prompt or creative
interpretation (e.g., "photograph", "oil painting", "watercolor", "3D render", "digital illustration", "pencil
sketch").
8. `artistic_style`: (String) If the style is not "photograph", describe its specific artistic characteristics, 3
words maximum. (e.g., "impressionistic, vibrant, textured" for an oil painting).
9. `context`: (String) Provide a general description of the type of image this would be. For example: "This is a
concept for a high-fashion editorial photograph intended for a magazine spread," or "This describes a piece of
concept art for a fantasy video game."
10. `text_render`: (Array of Objects) By default, this array should be empty (`[]`). Only add text objects to this
array if the user's prompt explicitly specifies the exact text content to be rendered (e.g., user asks for "a
poster with the title 'Cosmic Dream'"). Do not invent titles, names, or slogans for concepts like book covers or
posters unless the user provides them. A rare exception is for universally recognized text that is integral to an
object (e.g., the word 'STOP' on a 'stop sign'). For all other cases, if the user does not provide text, this array
must be empty.
* `text`: (String) The exact text content provided by the user. NEVER use generic placeholders.
* `location`: (String) E.g., "center", "top-left", "bottom-right foreground".
* `size`: (String) E.g., "medium", "large", "large within frame".
* `color`: (String) E.g., "red", "blue", "green".
* `font`: (String) E.g., "realistic", "cartoonish", "minimalist", "serif typeface".
* `appearance_details`: (String) Any other notable visual details."""
@cache
def get_instructions(mode: str) -> Tuple[str, str]:
system_prompts = {}
system_prompts["Caption"] = """
You are a meticulous and perceptive Visual Art Director working for a leading Generative AI company. Your expertise
lies in analyzing images and extracting detailed, structured information. Your primary task is to analyze provided
images and generate a comprehensive JSON object describing them. Adhere strictly to the following structure and
guidelines: The output MUST be ONLY a valid JSON object. Do not include any text before or after the JSON object (e.g.,
no "Here is the JSON:", no explanations, no apologies). IMPORTANT: When describing human body parts, positions, or
actions, always describe them from the PERSON'S OWN PERSPECTIVE, not from the observer's viewpoint. For example, if a
person's left arm is raised (from their own perspective), describe it as "left arm" even if it appears on the right
side of the image from the viewer's perspective. The JSON object must contain the following keys precisely:
1. `short_description`: (String) A concise summary of the image content, 200 words maximum.
2. `objects`: (Array of Objects) List a maximum of 5 prominent objects; if there are more than 5, describe the rest in the
background. For each object, include:
* `description`: (String) a detailed description of the object, 100 words maximum.
* `location`: (String) E.g., "center", "top-left", "bottom-right foreground".
* `relative_size`: (String) E.g., "small", "medium", "large within frame".
* `shape_and_color`: (String) Describe the basic shape and dominant color.
* `texture`: (String) E.g., "smooth", "rough", "metallic", "furry".
* `appearance_details`: (String) Any other notable visual details.
* `relationship`: (String) Describe the relationship between the object and the other objects in the image.
* `orientation`: (String) Describe the orientation or positioning of the object, e.g., "upright", "tilted 45
degrees", "horizontal", "vertical", "facing left", "facing right", "upside down", "lying on its side".
if the object is a human or a human-like object, include the following:
* `pose`: (String) Describe the body position.
* `expression`: (String) Describe facial expression and emotion. E.g., "winking", "joyful", "serious",
"surprised", "calm".
* `clothing`: (String) Describe attire.
* `action`: (String) Describe the action of the human.
* `gender`: (String) Describe the gender of the human.
* `skin_tone_and_texture`: (String) Describe the skin tone and texture.
if the object is a cluster of objects, include the following:
* `number_of_objects`: (Integer) The number of objects in the cluster.
3. `background_setting`: (String) Describe the overall environment, setting, or background, including any notable
background elements that are not part of the objects section.
4. `lighting`: (Object)
* `conditions`: (String) E.g., "bright daylight", "dim indoor", "studio lighting", "golden hour".
* `direction`: (String) E.g., "front-lit", "backlit", "side-lit from left".
* `shadows`: (String) Describe the presence of shadows.
5. `aesthetics`: (Object)
* `composition`: (String) E.g., "rule of thirds", "symmetrical", "centered", "leading lines".
* `color_scheme`: (String) E.g., "monochromatic blue", "warm complementary colors", "high contrast".
* `mood_atmosphere`: (String) E.g., "serene", "energetic", "mysterious", "joyful".
6. `photographic_characteristics`: (Object)
* `depth_of_field`: (String) E.g., "shallow", "deep", "bokeh background".
* `focus`: (String) E.g., "sharp focus on subject", "soft focus", "motion blur".
* `camera_angle`: (String) E.g., "eye-level", "low angle", "high angle", "dutch angle".
* `lens_focal_length`: (String) E.g., "wide-angle", "telephoto", "macro", "fisheye".
7. `style_medium`: (String) Identify the artistic style or medium (e.g., "photograph", "oil painting", "watercolor",
"3D render", "digital illustration", "pencil sketch") If the style is not "photograph", but artistic, please
describe the specific artistic characteristics under 'artistic_style', 50 words maximum.
8. `artistic_style`: (String) describe specific artistic characteristics, 3 words maximum.
9. `context`: (String) Provide any additional context that helps understand the image better. This should include a
general description of the type of image (e.g., Fashion Photography, Product Shot, Magazine Cover, Nature
Photography, Art Piece, etc.), as well as any other relevant contextual information that situates the image within a
broader category or intended use. For example: "This is a high-fashion editorial photograph intended for a magazine
spread"
10. `text_render`: (Array of Objects) List of a maximum of 5 most prominent text renders in the image. For each text
render, include:
* `text`: (String) The text content.
* `location`: (String) E.g., "center", "top-left", "bottom-right foreground".
* `size`: (String) E.g., "small", "medium", "large within frame".
* `color`: (String) E.g., "red", "blue", "green".
* `font`: (String) E.g., "realistic", "cartoonish", "minimalist".
* `appearance_details`: (String) Any other notable visual details.
Ensure the information within the JSON is accurate, detailed where specified, and avoids redundancy between fields.
"""
system_prompts[
"Generate"
] = f"""You are a visionary and creative Visual Art Director at a leading Generative AI company.
Your expertise lies in taking a user's textual concept and transforming it into a rich, detailed, and aesthetically
compelling visual scene.
Your primary task is to receive a user's description of a desired image and generate a comprehensive JSON object that
describes this imagined scene in vivid detail. You must creatively infer and add details that are not explicitly
mentioned in the user's request, such as background elements, lighting conditions, composition, and mood, always aiming
for a high-quality, visually appealing result unless the user's prompt suggests otherwise.
Adhere strictly to the following structure and guidelines:
The output MUST be ONLY a valid JSON object. Do not include any text before or after the JSON object (e.g., no "Here is
the JSON:", no explanations, no apologies).
IMPORTANT: When describing human body parts, positions, or actions, always describe them from the PERSON'S OWN
PERSPECTIVE, not from the observer's viewpoint. For example, if a person's left arm is raised (from their own
perspective), describe it as "left arm" even if it appears on the right side of the image from the viewer's
perspective.
RULE for Human Subjects: When the user's prompt features a person or people as the main subject, you MUST default to a
composition that frames them prominently. Aim for compositions where their face and upper body are a primary focus
(e.g., 'medium shot', 'close-up'). Avoid defaulting to 'wide-angle' or 'full-body' shots where the face is small,
unless the user's prompt specifically implies a large scene (e.g., "a person standing on a mountain").
Unless the user's prompt explicitly requests a different style (e.g., 'painting', 'cartoon', 'illustration'), you MUST
default to `style_medium: "photograph"` and aim for the highest degree of photorealism. In such cases, `artistic_style`
should be "realistic" or a similar descriptor.
The JSON object must contain the following keys precisely:
{json_schema_full}
Ensure the information within the JSON is detailed, creative, internally consistent, and avoids redundancy between
fields."""
system_prompts[
"RefineA"
] = f"""You are a Meticulous Visual Editor and Senior Art Director at a leading Generative AI company.
Your expertise is in refining and modifying existing visual concepts based on precise feedback.
Your primary task is to receive an existing JSON object that describes a visual scene, along with a textual instruction
for how to change it. You must then generate a new, updated JSON object that perfectly incorporates the requested
changes.
Adhere strictly to the following structure and guidelines:
1. **Input:** You will receive two pieces of information: an existing JSON object and a textual instruction.
2. **Output:** Your output MUST be ONLY a single, valid JSON object in the specified schema. Do not include any text
before or after the JSON object.
3. **Modification Logic:**
* Carefully parse the user's textual instruction to understand the desired changes.
* Modify ONLY the fields in the JSON that are directly or logically affected by the instruction.
* All other fields not relevant to the change must be copied exactly from the original JSON. Do not alter or omit
them.
4. **Holistic Consistency (IMPORTANT):** Changes in one field must be logically reflected in others. For example:
* If the instruction is to "change the background to a snowy forest," you must update the `background_setting`
field, and also update the `short_description` to mention the new setting. The `mood_atmosphere` might also need
to change to "serene" or "wintry."
* If the instruction is to "add the text 'WINTER SALE' at the top," you must add a new entry to the `text_render`
array.
* If the instruction is to "make the person smile," you must update the `expression` field for that object and
potentially update the overall `mood_atmosphere`.
5. **Schema Adherence:** The new JSON object you generate must strictly follow the schema provided below.
The JSON object must contain the following keys precisely:
{json_schema_full}"""
system_prompts[
"RefineB"
] = f"""You are an advanced Multimodal Visual Specialist at a leading Generative AI company.
Your unique expertise is in analyzing and editing visual concepts by processing an image, its corresponding JSON
metadata, and textual feedback simultaneously.
Your primary task is to receive three inputs: an existing image, its descriptive JSON object, and a textual instruction
for a modification. You must use the image as the primary source of truth to understand the context of the requested
change and then generate a new, updated JSON object that accurately reflects that change.
Adhere strictly to the following structure and guidelines:
1. **Inputs:** You will receive an image, an existing JSON object, and a textual instruction.
2. **Visual Grounding (IMPORTANT):** The provided image is the ground truth. Use it to visually verify the contents of
the scene and to understand the context of the user's edit instruction. For example, if the instruction is "make
the car blue," visually locate the car in the image to inform your edits to the JSON.
3. **Output:** Your output MUST be ONLY a single, valid JSON object in the specified schema. Do not include any text
before or after the JSON object.
4. **Modification Logic:**
* Analyze the user's textual instruction in the context of what you see in the image.
* Modify ONLY the fields in the JSON that are directly or logically affected by the instruction.
* All other fields not relevant to the change must be copied exactly from the original JSON.
5. **Holistic Consistency:** Changes must be reflected logically across the JSON, consistent with a potential visual
change to the image. For instance, changing the lighting from 'daylight' to 'golden hour' should not only update
the `lighting` object but also the `mood_atmosphere`, `shadows`, and the `short_description`.
6. **Schema Adherence:** The new JSON object you generate must strictly follow the schema provided below.
The JSON object must contain the following keys precisely:
{json_schema_full}"""
system_prompts[
"InspireA"
] = f"""You are a highly skilled Creative Director for Visual Adaptation at a leading Generative AI company.
Your expertise lies in using an existing image as a visual reference to create entirely new scenes. You can deconstruct
a reference image to understand its subject, pose, and style, and then reimagine it in a new context based on textual
instructions.
Your primary task is to receive a reference image and a textual instruction. You will analyze the reference to extract
key visual information and then generate a comprehensive JSON object describing a new scene that creatively
incorporates the user's instructions.
Adhere strictly to the following structure and guidelines:
1. **Inputs:** You will receive a reference image and a textual instruction. You will NOT receive a starting JSON.
2. **Core Logic (Analyze and Synthesize):**
* **Analyze:** First, deeply analyze the provided reference image. Identify its primary subject(s), their specific
poses, expressions, and appearance. Also note the overall composition, lighting style, and artistic medium.
* **Synthesize:** Next, interpret the textual instruction to understand what elements to keep from the reference
and what to change. You will then construct a brand new JSON object from scratch that describes the desired final
scene. For example, if the instruction is "the same dog and pose, but at the beach," you must describe the dog
from the reference image in the `objects` array but create a new `background_setting` for a beach, with
appropriate `lighting` and `mood_atmosphere`.
3. **Output:** Your output MUST be ONLY a single, valid JSON object that describes the **new, imagined scene**. Do not
describe the original reference image.
4. **Holistic Consistency:** Ensure the generated JSON is internally consistent. A change in the environment should be
reflected logically across multiple fields, such as `background_setting`, `lighting`, `shadows`, and the
`short_description`.
5. **Schema Adherence:** The new JSON object you generate must strictly follow the schema provided below.
The JSON object must contain the following keys precisely:
{json_schema_full}"""
system_prompts["InspireB"] = system_prompts["Caption"]
final_prompts = {}
final_prompts["Generate"] = (
"Generate a detailed JSON object, adhering to the expected schema, for an imagined scene based on the following request: {user_prompt}."
)
final_prompts["RefineA"] = """
[EXISTING JSON]: {json_data}
[EDIT INSTRUCTIONS]: {user_prompt}
[TASK]: Generate the new, updated JSON object that incorporates the edit instructions. Follow all system rules
for modification, consistency, and formatting.
"""
final_prompts["RefineB"] = """
[EXISTING JSON]: {json_data}
[EDIT INSTRUCTIONS]: {user_prompt}
[TASK]: Analyze the provided image and its contextual JSON. Then, generate the new, updated JSON object that
incorporates the edit instructions. Follow all your system rules for visual analysis, modification, and
consistency.
"""
final_prompts["InspireA"] = """
[EDIT INSTRUCTIONS]: {user_prompt}
[TASK]: Use the provided image as a visual reference only. Analyze its key elements (like the subject and pose)
and then generate a new, detailed JSON object for the scene described in the instructions above. Do not
describe the reference image itself; describe the new scene. Follow all of your system rules.
"""
final_prompts["Caption"] = (
"Analyze the provided image and generate the detailed JSON object as specified in your instructions."
)
final_prompts["InspireB"] = final_prompts["Caption"]
return system_prompts.get(mode, ""), final_prompts.get(mode, "")
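As a small illustration of the lookup above (the request string is invented), the "Generate" mode pairs the text-only system prompt with a final prompt that embeds the raw user request:

```python
system_prompt, final_prompt = get_instructions("Generate")
final_prompt = final_prompt.format(user_prompt="a red bicycle leaning against a brick wall at dawn")
# `system_prompt` carries the schema rules; `final_prompt` becomes the user turn sent to the VLM.
```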
def keep(p, k, v):
is_none = v is None
is_empty_string = isinstance(v, str) and v == ""
is_empty_dict = isinstance(v, dict) and not v
is_empty_list = isinstance(v, list) and not v
is_nan = isinstance(v, float) and math.isnan(v)
if is_none or is_empty_string or is_empty_list or is_empty_dict:
return False
if is_nan:
return False
return True
def validate_json(json_data: str) -> dict:
ia = ImageAnalysis.model_validate_json(json_data, strict=True)
return ia.model_dump(exclude_none=True)
def validate_structured_prompt_str(structured_prompt_str: str) -> str:
ia = ImageAnalysis.model_validate_json(structured_prompt_str, strict=True)
c = ia.model_dump(exclude_none=True)
return json.dumps(c)
def prepare_clean_caption(json_dump: dict) -> str:
# filter empty values recursively (i.e. None, "", {}, [], float("nan"))
clean_caption_dict = remap(json_dump, visit=keep)
scores = {"preference_score": "very high", "aesthetic_score": "very high"}
# Set aesthetics scores
if "aesthetics" not in clean_caption_dict:
clean_caption_dict["aesthetics"] = scores
else:
clean_caption_dict["aesthetics"].update(scores)
# Dump the clean structured caption as a minimal JSON string (no newline/whitespace separators)
clean_caption_str = json.dumps(clean_caption_dict)
return clean_caption_str
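A minimal sketch of what this cleaning step does, assuming `remap` here is `boltons.iterutils.remap` (its import sits above this excerpt); the sample dictionary is invented for illustration:

```python
import json

from boltons.iterutils import remap

sample = {
    "style_medium": "photograph",
    "objects": [],  # dropped: empty list
    "lighting": {"type": "golden hour", "notes": None},  # inner None dropped
    "aesthetics": {},  # dropped, then re-added with the pinned scores
}

filtered = remap(sample, visit=keep)
print(json.dumps(filtered))
# {"style_medium": "photograph", "lighting": {"type": "golden hour"}}

print(prepare_clean_caption(sample))
# same content plus {"aesthetics": {"preference_score": "very high", "aesthetic_score": "very high"}}
```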
# Resize an input image to approximately `pixel_number` pixels (default 1,048,576, i.e. 1024x1024 worth of area)
# while preserving its aspect ratio; the output width and height are rounded down to multiples of `granularity_val`.
def resize_image_by_num_pixels(
image: Image.Image, pixel_number: int = 1048576, granularity_val: int = 64, target_ratio: float = 0.0
) -> Image.Image:
if target_ratio != 0.0:
ratio = target_ratio
else:
ratio = image.size[0] / image.size[1]
width = int((pixel_number * ratio) ** 0.5)
width = width - (width % granularity_val)
height = int(pixel_number / width)
height = height - (height % granularity_val)
return image.resize((width, height))
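A worked example of the size arithmetic, reusing the helper defined above with the default budget and the reduced budget used for Gemini inputs (the input dimensions are arbitrary):

```python
from PIL import Image

img = Image.new("RGB", (1536, 1024))  # 3:2 input, ~1.57 MP

resized = resize_image_by_num_pixels(img)  # default: ~1.05 MP, sides rounded to multiples of 64
print(resized.size)  # (1216, 832)

small = resize_image_by_num_pixels(img, pixel_number=262144, granularity_val=1)
print(small.size)  # (627, 418) -- the ~512x512 budget used before calling Gemini
```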
def infer_with_gemini(
client: genai.Client,
final_prompt: str,
system_prompt: str,
top_p: float,
temperature: float,
max_tokens: int,
seed: int = 42,
image: Optional[Image.Image] = None,
model: str = "gemini-2.5-flash",
) -> str:
"""
Calls Gemini API with the given prompt and returns the raw JSON response.
Args:
final_prompt: The text prompt to send to Gemini
system_prompt: The system instruction for Gemini
existing_image_path: Optional path to an image file to include
model: The Gemini model to use
Returns:
Raw JSON response text from Gemini
"""
parts = [{"text": final_prompt}]
if image:
# Save image into bytes
image = image.convert("RGB") # the model can't produce rgba so sending them as input has no effect
less_then = 262144
if image.size[0] * image.size[1] > less_then:
image = resize_image_by_num_pixels(
image, pixel_number=less_then, granularity_val=1, target_ratio=0.0
) # 512x512
buffer = io.BytesIO()
image.save(buffer, format="JPEG")
image_bytes = buffer.getvalue()
img_part = {
"inlineData": {
"data": image_bytes,
"mimeType": "image/jpeg",
}
}
parts.append(img_part)
contents = [{"role": "user", "parts": parts}]
generationConfig = {
"temperature": temperature,
"topP": top_p,
"maxOutputTokens": max_tokens,
"response_mime_type": "application/json",
"response_schema": get_gemini_output_schema(),
"system_instruction": system_prompt, # len 5900
"thinkingConfig": {"thinkingBudget": 0},
"seed": seed,
}
response = client.models.generate_content(
model=model,
contents=contents,
config=generationConfig,
)
if response.candidates[0].finish_reason == "MAX_TOKENS":
raise Exception("Max tokens")
return response.candidates[0].content.parts[0].text
def get_default_negative_prompt(existing_json: dict) -> str:
negative_prompt = ""
style_medium = existing_json.get("style_medium", "").lower()
if style_medium in ["photograph", "photography", "photo"]:
negative_prompt = """{'style_medium': 'digital illustration', 'artistic_style': 'non-realistic'}"""
return negative_prompt
def json_promptify(
client: genai.Client,
model_id: str,
top_p: float,
temperature: float,
max_tokens: int,
user_prompt: Optional[str] = None,
existing_json: Optional[str] = None,
image: Optional[Image.Image] = None,
seed: int = 42,
) -> str:
if existing_json:
# make sure aesthetic scores are not in the existing json (will be added later)
existing_json = json.loads(existing_json)
if "aesthetics" in existing_json:
existing_json["aesthetics"].pop("aesthetic_score", None)
existing_json["aesthetics"].pop("preference_score", None)
existing_json = json.dumps(existing_json)
if not user_prompt:
raise ValueError("user_prompt is required if existing_json is provided")
if image:
mode = "RefineB"
system_prompt, final_prompt = get_instructions(mode)
final_prompt = final_prompt.format(user_prompt=user_prompt, json_data=existing_json)
else:
mode = "RefineA"
system_prompt, final_prompt = get_instructions(mode)
final_prompt = final_prompt.format(user_prompt=user_prompt, json_data=existing_json)
elif image and user_prompt:
mode = "InspireA"
system_prompt, final_prompt = get_instructions(mode)
final_prompt = final_prompt.format(user_prompt=user_prompt)
elif image and not user_prompt:
mode = "Caption"
system_prompt, final_prompt = get_instructions(mode)
else:
mode = "Generate"
system_prompt, final_prompt = get_instructions(mode)
final_prompt = final_prompt.format(user_prompt=user_prompt)
json_data = infer_with_gemini(
client=client,
model=model_id,
final_prompt=final_prompt,
system_prompt=system_prompt,
seed=seed,
image=image,
top_p=top_p,
temperature=temperature,
max_tokens=max_tokens,
)
json_data = validate_json(json_data)
clean_caption = prepare_clean_caption(json_data)
return clean_caption
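A hedged usage sketch of the full prompt-to-JSON path; it assumes the `google-genai` SDK and a `GOOGLE_API_KEY` in the environment, and the request text is invented. With neither an image nor an existing JSON, this takes the "Generate" branch above:

```python
from google import genai

client = genai.Client()  # reads GOOGLE_API_KEY from the environment

structured_prompt = json_promptify(
    client=client,
    model_id="gemini-2.5-flash",
    top_p=1.0,
    temperature=0.2,
    max_tokens=3000,
    user_prompt="a lighthouse on a cliff at golden hour, shot on 35mm film",
)
print(structured_prompt)  # minified JSON string following the FIBO schema
```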
class BriaFiboGeminiPromptToJson(ModularPipelineBlocks):
model_name = "BriaFibo"
def __init__(self, model_id="gemini-2.5-flash"):
super().__init__()
api_key = os.getenv("GOOGLE_API_KEY")
if api_key is None:
raise ValueError("Must provide an API key for Gemini through the `GOOGLE_API_KEY` env variable.")
self.model_id = model_id
@property
def expected_components(self):
return []
@property
def inputs(self) -> List[InputParam]:
task_input = InputParam("task", type_hint=str, required=False, description="VLM Task to execute")
prompt_input = InputParam(
"prompt",
type_hint=str,
required=False,
description="Prompt to use",
)
image_input = InputParam(
name="image", type_hint=Image.Image, required=False, description="image for inspiration mode"
)
json_prompt_input = InputParam(
name="json_prompt", type_hint=str, required=False, description="JSON prompt to use"
)
sampling_top_p_input = InputParam(
name="sampling_top_p", type_hint=float, required=False, description="Sampling top p", default=1.0
)
sampling_temperature_input = InputParam(
name="sampling_temperature",
type_hint=float,
required=False,
description="Sampling temperature",
default=0.2,
)
sampling_max_tokens_input = InputParam(
name="sampling_max_tokens", type_hint=int, required=False, description="Sampling max tokens", default=3000
)
return [
task_input,
prompt_input,
image_input,
json_prompt_input,
sampling_top_p_input,
sampling_temperature_input,
sampling_max_tokens_input,
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return []
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"json_prompt",
type_hint=str,
description="JSON prompt by the VLM",
)
]
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
prompt = block_state.prompt
image = block_state.image
json_prompt = block_state.json_prompt
client = genai.Client()
json_prompt = json_promptify(
client=client,
model_id=self.model_id,
top_p=block_state.sampling_top_p,
temperature=block_state.sampling_temperature,
max_tokens=block_state.sampling_max_tokens,
user_prompt=prompt,
existing_json=json_prompt,
image=image,
)
block_state.json_prompt = json_prompt
self.set_block_state(state, block_state)
return components, state
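A sketch of driving the block through the modular-pipelines interface, assuming the usual flow of materializing a block with `init_pipeline()` and requesting a named output; the exact invocation may differ, and `GOOGLE_API_KEY` must be set:

```python
block = BriaFiboGeminiPromptToJson()
prompt_to_json = block.init_pipeline()

json_prompt = prompt_to_json(
    prompt="a foggy harbor at dawn, cinematic lighting",
    output="json_prompt",
)
```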

View File

@@ -128,6 +128,7 @@ else:
"AnimateDiffVideoToVideoControlNetPipeline",
]
_import_structure["bria"] = ["BriaPipeline"]
_import_structure["bria_fibo"] = ["BriaFiboPipeline"]
_import_structure["flux"] = [
"FluxControlPipeline",
"FluxControlInpaintPipeline",
@@ -562,6 +563,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .bria import BriaPipeline
from .bria_fibo import BriaFiboPipeline
from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
from .cogvideo import (
CogVideoXFunControlPipeline,

View File

@@ -0,0 +1,48 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_bria_fibo import BriaFiboPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,826 @@
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import AutoTokenizer
from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM
from ...image_processor import VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin
from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan
from ...models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
from ...pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput
from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
...
"""
class BriaFiboPipeline(DiffusionPipeline):
r"""
Args:
transformer ([`GaiaTransformer2DModel`]):
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`SmolLM3ForCausalLM`]):
tokenizer (`AutoTokenizer`):
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
transformer: BriaFiboTransformer2DModel,
scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
vae: AutoencoderKLWan,
text_encoder: SmolLM3ForCausalLM,
tokenizer: AutoTokenizer,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.default_sample_size = 64
def get_prompt_embeds(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
max_sequence_length: int = 2048,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
if not prompt:
raise ValueError("`prompt` must be a non-empty string or list of strings.")
batch_size = len(prompt)
bot_token_id = 128000  # begin-of-text token id, used to stand in for empty prompts
text_encoder_device = device if device is not None else torch.device("cpu")
if not isinstance(text_encoder_device, torch.device):
text_encoder_device = torch.device(text_encoder_device)
if all(p == "" for p in prompt):
input_ids = torch.full((batch_size, 1), bot_token_id, dtype=torch.long, device=text_encoder_device)
attention_mask = torch.ones_like(input_ids)
else:
tokenized = self.tokenizer(
prompt,
padding="longest",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
input_ids = tokenized.input_ids.to(text_encoder_device)
attention_mask = tokenized.attention_mask.to(text_encoder_device)
if any(p == "" for p in prompt):
empty_rows = torch.tensor([p == "" for p in prompt], dtype=torch.bool, device=text_encoder_device)
input_ids[empty_rows] = bot_token_id
attention_mask[empty_rows] = 1
encoder_outputs = self.text_encoder(
input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
)
hidden_states = encoder_outputs.hidden_states
prompt_embeds = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1)
prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
hidden_states = tuple(
layer.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) for layer in hidden_states
)
attention_mask = attention_mask.repeat_interleave(num_images_per_prompt, dim=0).to(device=device)
return prompt_embeds, hidden_states, attention_mask
@staticmethod
def pad_embedding(prompt_embeds, max_tokens, attention_mask=None):
# Pad embeddings to `max_tokens` while preserving the mask of real tokens.
batch_size, seq_len, dim = prompt_embeds.shape
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_len), dtype=prompt_embeds.dtype, device=prompt_embeds.device)
else:
attention_mask = attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
if max_tokens < seq_len:
raise ValueError("`max_tokens` must be greater or equal to the current sequence length.")
if max_tokens > seq_len:
pad_length = max_tokens - seq_len
padding = torch.zeros(
(batch_size, pad_length, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
)
prompt_embeds = torch.cat([prompt_embeds, padding], dim=1)
mask_padding = torch.zeros(
(batch_size, pad_length), dtype=prompt_embeds.dtype, device=prompt_embeds.device
)
attention_mask = torch.cat([attention_mask, mask_padding], dim=1)
return prompt_embeds, attention_mask
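A minimal shape check for the padding helper above (dummy tensors only):

```python
import torch

embeds = torch.randn(2, 5, 8)  # two prompts, 5 tokens each
mask = torch.ones(2, 5)

padded, padded_mask = BriaFiboPipeline.pad_embedding(embeds, max_tokens=7, attention_mask=mask)
print(padded.shape, padded_mask.shape)  # torch.Size([2, 7, 8]) torch.Size([2, 7])
print(padded_mask[0])                   # tensor([1., 1., 1., 1., 1., 0., 0.])
```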
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 3000,
lora_scale: Optional[float] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
device = device or self._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
prompt_attention_mask = None
negative_prompt_attention_mask = None
if prompt_embeds is None:
prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers]
if do_classifier_free_guidance:
if isinstance(negative_prompt, list) and negative_prompt[0] is None:
negative_prompt = ""
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds(
prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype)
negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers]
if self.text_encoder is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
# Pad to longest
if prompt_attention_mask is not None:
prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
if negative_prompt_embeds is not None:
if negative_prompt_attention_mask is not None:
negative_prompt_attention_mask = negative_prompt_attention_mask.to(
device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype
)
max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1])
prompt_embeds, prompt_attention_mask = self.pad_embedding(
prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
)
prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers]
negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding(
negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask
)
negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers]
else:
max_tokens = prompt_embeds.shape[1]
prompt_embeds, prompt_attention_mask = self.pad_embedding(
prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
)
negative_prompt_layers = None
dtype = self.text_encoder.dtype
text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype)
return (
prompt_embeds,
negative_prompt_embeds,
text_ids,
prompt_attention_mask,
negative_prompt_attention_mask,
prompt_layers,
negative_prompt_layers,
)
@property
def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor
width = width // vae_scale_factor
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
return latents
@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod
def _unpack_latents_no_patch(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor
width = width // vae_scale_factor
latents = latents.view(batch_size, height, width, channels)
latents = latents.permute(0, 3, 1, 2)
return latents
@staticmethod
def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width):
latents = latents.permute(0, 2, 3, 1)
latents = latents.reshape(batch_size, height * width, num_channels_latents)
return latents
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
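A round trip through the no-patch packing used when `do_patching=False` (dummy latents; `vae_scale_factor=16` as set in `__init__`):

```python
import torch

lat = torch.randn(1, 16, 8, 8)  # latent grid for a 128x128 image at vae_scale_factor=16

packed = BriaFiboPipeline._pack_latents_no_patch(lat, 1, 16, 8, 8)
print(packed.shape)  # torch.Size([1, 64, 16]) -- one token row per latent pixel

restored = BriaFiboPipeline._unpack_latents_no_patch(packed, 128, 128, vae_scale_factor=16)
print(torch.equal(restored, lat))  # True
```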
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
do_patching=False,
):
height = int(height) // self.vae_scale_factor
width = int(width) // self.vae_scale_factor
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if do_patching:
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
else:
latents = self._pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
return latents, latent_image_ids
@staticmethod
def init_inference_scheduler(height, width, device, image_seq_len, num_inference_steps=1000, noise_scheduler=None):
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
assert height % 16 == 0 and width % 16 == 0
mu = calculate_shift(
image_seq_len,
noise_scheduler.config.base_image_seq_len,
noise_scheduler.config.max_image_seq_len,
noise_scheduler.config.base_shift,
noise_scheduler.config.max_shift,
)
# Initialize sigmas and timesteps according to the computed shift;
# this mutates the scheduler in place (dynamic, resolution-dependent scheduling)
timesteps, num_inference_steps = retrieve_timesteps(
noise_scheduler,
num_inference_steps=num_inference_steps,
device=device,
timesteps=None,
sigmas=sigmas,
mu=mu,
)
return noise_scheduler, timesteps, num_inference_steps, mu
@staticmethod
def create_attention_matrix(attention_mask):
attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask)
# convert to 0 - keep, -inf ignore
attention_matrix = torch.where(
attention_matrix == 1, 0.0, -torch.inf
) # Apply -inf to ignored tokens for nulling softmax score
return attention_matrix
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 30,
timesteps: List[int] = None,
guidance_scale: float = 5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 3000,
do_patching=False,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 30):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.bria_fibo.BriaFiboPipelineOutput`] instead of a plain
tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, *optional*, defaults to 3000): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.bria_fibo.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.bria_fibo.BriaFiboPipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
height=height,
width=width,
prompt_embeds=prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
(
prompt_embeds,
negative_prompt_embeds,
text_ids,
prompt_attention_mask,
negative_prompt_attention_mask,
prompt_layers,
negative_prompt_layers,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
max_sequence_length=max_sequence_length,
num_images_per_prompt=num_images_per_prompt,
lora_scale=lora_scale,
)
prompt_batch_size = prompt_embeds.shape[0]
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_layers = [
torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers))
]
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
total_num_layers_transformer = len(self.transformer.transformer_blocks) + len(
self.transformer.single_transformer_blocks
)
if len(prompt_layers) >= total_num_layers_transformer:
# keep only the last `total_num_layers_transformer` hidden states
prompt_layers = prompt_layers[len(prompt_layers) - total_num_layers_transformer :]
else:
# repeat the last hidden state so every transformer block gets a layer
prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers))
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
if do_patching:
num_channels_latents = int(num_channels_latents / 4)
latents, latent_image_ids = self.prepare_latents(
prompt_batch_size,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
do_patching,
)
latent_attention_mask = torch.ones(
[latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device
)
if self.do_classifier_free_guidance:
latent_attention_mask = latent_attention_mask.repeat(2, 1)
attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1)
attention_mask = self.create_attention_matrix(attention_mask) # batch, seq => batch, seq, seq
attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting
if self._joint_attention_kwargs is None:
self._joint_attention_kwargs = {}
self._joint_attention_kwargs["attention_mask"] = attention_mask
# Adapt scheduler to dynamic shifting (resolution dependent)
if do_patching:
seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2))
else:
seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor)
self.noise_scheduler, timesteps, num_inference_steps, mu = self.init_inference_scheduler(
height=height,
width=width,
device=device,
num_inference_steps=num_inference_steps,
noise_scheduler=self.scheduler,
image_seq_len=seq_len,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# Support older diffusers versions that return 3D position-id tensors
if len(latent_image_ids.shape) == 3:
latent_image_ids = latent_image_ids[0]
if len(text_ids.shape) == 3:
text_ids = text_ids[0]
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).to(
device=latent_model_input.device, dtype=latent_model_input.dtype
)
# This predicts the flow-matching velocity ("v") or the diffusion noise ("eps"), depending on the formulation
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
text_encoder_layers=prompt_layers,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
txt_ids=text_ids,
img_ids=latent_image_ids,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
image = latents
else:
if do_patching:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
else:
latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor)
latents = latents.to(dtype=self.vae.dtype)
latents = latents.unsqueeze(dim=2)
latents = list(torch.unbind(latents, dim=0))
latents_device = latents[0].device
latents_dtype = latents[0].dtype
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents_device, latents_dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents_device, latents_dtype
)
latents_scaled = [latent / latents_std + latents_mean for latent in latents]
latents_scaled = torch.cat(latents_scaled, dim=0)
image = []
for scaled_latent in latents_scaled:
curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0]
curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type)
image.append(curr_image)
if len(image) == 1:
image = image[0]
else:
image = np.stack(image, axis=0)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return BriaFiboPipelineOutput(images=image)
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if max_sequence_length is not None and max_sequence_length > 3000:
raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}")
def to(self, *args, **kwargs):
DiffusionPipeline.to(self, *args, **kwargs)
# Keep the VAE in float32, matching the reference Wan 2.2 implementation
self.vae.to(dtype=torch.float32)
return self
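A hedged end-to-end sketch of running the pipeline, assuming access to the gated `briaai/FIBO` checkpoint, a CUDA device, and bfloat16 weights; the structured prompt below is a toy stand-in for a full FIBO JSON caption, and the generation settings are illustrative:

```python
import torch

from diffusers import BriaFiboPipeline

pipe = BriaFiboPipeline.from_pretrained("briaai/FIBO", torch_dtype=torch.bfloat16)
pipe.to("cuda")

json_prompt = '{"short_description": "a lighthouse on a cliff at golden hour", "style_medium": "photograph"}'
image = pipe(
    prompt=json_prompt,
    num_inference_steps=30,
    guidance_scale=5.0,
    height=1024,
    width=1024,
).images[0]
image.save("fibo.png")
```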

View File

@@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from ...utils import BaseOutput
@dataclass
class BriaFiboPipelineOutput(BaseOutput):
"""
Output class for BriaFibo pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]

View File

@@ -0,0 +1,132 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import BriaFiboTransformer2DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class BriaFiboTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = BriaFiboTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.8, 0.7, 0.7]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 1
num_latent_channels = 48
num_image_channels = 3
height = width = 16
sequence_length = 32
embedding_dim = 64
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"text_encoder_layers": [encoder_hidden_states[:,:,:32], encoder_hidden_states[:,:,:32]],
}
@property
def input_shape(self):
return (16, 16)
@property
def output_shape(self):
return (256, 48)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 1,
"in_channels": 48,
"num_layers": 1,
"num_single_layers": 1,
"attention_head_dim": 8,
"num_attention_heads": 2,
"joint_attention_dim": 64,
"text_encoder_dim": 32,
"pooled_projection_dim": None,
"axes_dims_rope": [0, 4, 4],
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_deprecated_inputs_img_txt_ids_3d(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output_1 = model(**inputs_dict)[0]
# update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"
inputs_dict["txt_ids"] = text_ids_3d
inputs_dict["img_ids"] = image_ids_3d
with torch.no_grad():
output_2 = model(**inputs_dict)[0]
self.assertEqual(output_1.shape, output_2.shape)
self.assertTrue(
torch.allclose(output_1, output_2, atol=1e-5),
msg="output with deprecated 3d `img_ids`/`txt_ids` inputs does not match the output with the same ids passed as 2d tensors",
)
def test_gradient_checkpointing_is_applied(self):
expected_set = {"BriaFiboTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class BriaFiboTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = BriaFiboTransformer2DModel
def prepare_init_args_and_inputs_for_common(self):
return BriaFiboTransformerTests().prepare_init_args_and_inputs_for_common()
class BriaFiboTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = BriaFiboTransformer2DModel
def prepare_init_args_and_inputs_for_common(self):
return BriaFiboTransformerTests().prepare_init_args_and_inputs_for_common()

View File

View File

@@ -0,0 +1,198 @@
# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer
from transformers.models.smollm3.modeling_smollm3 import SmolLM3Config, SmolLM3ForCausalLM
from diffusers import (
AutoencoderKLWan,
BriaFiboPipeline,
FlowMatchEulerDiscreteScheduler,
)
from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
from tests.pipelines.test_pipelines_common import PipelineTesterMixin, to_np
from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
enable_full_determinism()
class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = BriaFiboPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale"])
batch_params = frozenset(["prompt"])
test_xformers_attention = False
test_layerwise_casting = False
test_group_offloading = False
def get_dummy_components(self):
torch.manual_seed(0)
transformer = BriaFiboTransformer2DModel(
patch_size=1,
in_channels=16,
num_layers=1,
num_single_layers=1,
attention_head_dim=8,
num_attention_heads=2,
joint_attention_dim=64,
text_encoder_dim=32,
pooled_projection_dim=None,
axes_dims_rope=[0, 4, 4],
)
torch.manual_seed(0)
vae = AutoencoderKLWan(
base_dim=160,
decoder_base_dim=256,
num_res_blocks=2,
out_channels=12,
patch_size=2,
scale_factor_spatial=16,
scale_factor_temporal=4,
temperal_downsample=[False, True, True],
z_dim=16,
)
scheduler = FlowMatchEulerDiscreteScheduler()
torch.manual_seed(0)
text_encoder = SmolLM3ForCausalLM(SmolLM3Config(hidden_size=32))
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"transformer": transformer,
"vae": vae,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": "{'text': 'A painting of a squirrel eating a burger'}",
"negative_prompt": "bad, ugly",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"height": 32,
"width": 32,
"output_type": "np",
}
return inputs
def test_encode_prompt_works_in_isolation(self):
pass
def test_bria_fibo_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components())
pipe = pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt"] = "a different prompt"
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
assert max_diff > 1e-6
def test_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components())
pipe = pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32)]
for height, width in height_width_pairs:
expected_height = height
expected_width = width
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
@require_torch_accelerator
def test_save_load_float16(self, expected_max_diff=1e-2):
components = self.get_dummy_components()
for name, module in components.items():
if hasattr(module, "half"):
components[name] = module.to(torch_device).half()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0]
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16)
for component in pipe_loaded.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
for name, component in pipe_loaded.components.items():
if name == "vae":
continue
if hasattr(component, "dtype"):
self.assertTrue(
component.dtype == torch.float16,
f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
)
inputs = self.get_dummy_inputs(torch_device)
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(
max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
)
def test_to_dtype(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
def test_save_load_dduf(self):
pass