mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
galbria
2025-10-27 13:04:57 +00:00
parent 371e5f511e
commit a617433ace
8 changed files with 375 additions and 1409 deletions

View File

@@ -323,6 +323,8 @@
title: AllegroTransformer3DModel
- local: api/models/aura_flow_transformer2d
title: AuraFlowTransformer2DModel
- local: api/models/transformer_bria_fibo
title: BriaFiboTransformer2DModel
- local: api/models/bria_transformer
title: BriaTransformer2DModel
- local: api/models/chroma_transformer
@@ -469,6 +471,8 @@
title: BLIP-Diffusion
- local: api/pipelines/bria_3_2
title: Bria 3.2
- local: api/pipelines/bria_fibo
title: Bria Fibo
- local: api/pipelines/chroma
title: Chroma
- local: api/pipelines/cogview3

View File

@@ -1,18 +1,26 @@
# Copyright (c) Bria.ai. All rights reserved.
#
# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
#
# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
# indicate if changes were made, and do not use the material for commercial purposes.
#
# See the license for further details.
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention_processor import Attention
from ...models.embeddings import TimestepEmbedding, get_timestep_embedding
from ...models.embeddings import TimestepEmbedding, apply_rotary_emb, get_1d_rotary_pos_embed, get_timestep_embedding
from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZeroSingle
from ...models.transformers.transformer_bria import BriaAttnProcessor
from ...models.transformers.transformer_flux import FluxTransformerBlock
from ...utils import (
USE_PEFT_BACKEND,
logging,
@@ -20,76 +28,193 @@ from ...utils import (
unscale_lora_layers,
)
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
    freqs_dtype=torch.float32,  # torch.float32, torch.float64
):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions. This function calculates a
    frequency tensor with complex exponentials using the given dimension 'dim' and the end index 'end'. The 'theta'
    parameter scales the frequencies. The returned tensor contains complex values in complex64 data type.
    Args:
        dim (`int`): Dimension of the frequency tensor.
        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
        theta (`float`, *optional*, defaults to 10000.0):
            Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (`bool`, *optional*):
            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
        linear_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the context extrapolation. Defaults to 1.0.
        ntk_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
        repeat_interleave_real (`bool`, *optional*, defaults to `True`):
            If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
            Otherwise, they are concatenated with themselves.
        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
            the dtype of the frequency tensor.
    Returns:
        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
    """
    assert dim % 2 == 0

    if isinstance(pos, int):
        pos = torch.arange(pos)
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # type: ignore  # [S]

    theta = theta * ntk_factor
    freqs = (
        1.0
        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
        / linear_factor
    )  # [D/2]
    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
    if use_real and repeat_interleave_real:
        # flux, hunyuan-dit, cogvideox
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
        return freqs_cos, freqs_sin
    elif use_real:
        # stable audio, allegro
        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
        return freqs_cos, freqs_sin
    else:
        # lumina
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
        return freqs_cis


class EmbedND(torch.nn.Module):


def _get_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
    query = attn.to_q(hidden_states)
    key = attn.to_k(hidden_states)
    value = attn.to_v(hidden_states)

    encoder_query = encoder_key = encoder_value = None
    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
        encoder_query = attn.add_q_proj(encoder_hidden_states)
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)

    return query, key, value, encoder_query, encoder_key, encoder_value


def _get_fused_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
    query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)

    encoder_query = encoder_key = encoder_value = (None,)
    if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
        encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)

    return query, key, value, encoder_query, encoder_key, encoder_value
def _get_qkv_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
if attn.fused_projections:
return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
return _get_projections(attn, hidden_states, encoder_hidden_states)
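# Illustrative sketch (not part of the diff): how the shared `get_1d_rotary_pos_embed` imported
# above is typically used. With `use_real=True` it returns a (cos, sin) pair of shape [S, dim]
# that `apply_rotary_emb` consumes in the processor below. Dimensions and positions here are
# arbitrary example values.
#
#   positions = torch.arange(16)  # 16 token positions
#   cos, sin = get_1d_rotary_pos_embed(
#       64, positions, theta=10000.0, use_real=True, repeat_interleave_real=True
#   )
#   assert cos.shape == sin.shape == (16, 64)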
class BriaFiboAttnProcessor:
# Copied from diffusers.models.transformers.transformer_flux.FluxAttnProcessor
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
def __call__(
self,
attn: "BriaFiboAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
attn, hidden_states, encoder_hidden_states
)
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
query = attn.norm_q(query)
key = attn.norm_k(key)
if attn.added_kv_proj_dim is not None:
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
encoder_query = attn.norm_added_q(encoder_query)
encoder_key = attn.norm_added_k(encoder_key)
query = torch.cat([encoder_query, query], dim=1)
key = torch.cat([encoder_key, key], dim=1)
value = torch.cat([encoder_value, value], dim=1)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
class BriaFiboAttention(torch.nn.Module, AttentionModuleMixin):
# Copied from diffusers.models.transformers.transformer_flux.FluxAttention
_default_processor_cls = BriaFiboAttnProcessor
_available_processors = [
BriaFiboAttnProcessor,
]
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
out_bias: bool = True,
eps: float = 1e-5,
out_dim: int = None,
context_pre_only: Optional[bool] = None,
pre_only: bool = False,
elementwise_affine: bool = True,
processor=None,
):
super().__init__()
self.head_dim = dim_head
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.use_bias = bias
self.dropout = dropout
self.out_dim = out_dim if out_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.heads = out_dim // dim_head if out_dim is not None else heads
self.added_kv_proj_dim = added_kv_proj_dim
self.added_proj_bias = added_proj_bias
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
if not self.pre_only:
self.to_out = torch.nn.ModuleList([])
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(torch.nn.Dropout(dropout))
if added_kv_proj_dim is not None:
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
if len(unused_kwargs) > 0:
logger.warning(
f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
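# Illustrative usage sketch (not part of the diff): joint text/image attention with example
# sizes only (dim=3072, 24 heads of 128); the real model configuration may differ.
#
#   attn = BriaFiboAttention(
#       query_dim=3072, added_kv_proj_dim=3072, dim_head=128, heads=24, out_dim=3072, bias=True
#   )
#   img = torch.randn(1, 4096, 3072)  # latent/image tokens
#   txt = torch.randn(1, 512, 3072)   # text tokens
#   hidden, encoder_hidden = attn(hidden_states=img, encoder_hidden_states=txt)
#   # hidden: [1, 4096, 3072], encoder_hidden: [1, 512, 3072]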
class FIBOEmbedND(torch.nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
@@ -182,7 +307,93 @@ class TextProjection(nn.Module):
return hidden_states
class Timesteps(nn.Module):
@maybe_allow_in_graph
class BriaFiboTransformerBlock(nn.Module):
# Copied from diffusers.models.transformers.transformer_flux.FluxTransformerBlock
def __init__(
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
):
super().__init__()
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
self.attn = BriaFiboAttention(
query_dim=dim,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
processor=BriaFiboAttnProcessor(),
eps=eps,
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
joint_attention_kwargs = joint_attention_kwargs or {}
# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
if len(attention_outputs) == 2:
attn_output, context_attn_output = attention_outputs
elif len(attention_outputs) == 3:
attn_output, context_attn_output, ip_attn_output = attention_outputs
# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
if len(attention_outputs) == 3:
hidden_states = hidden_states + ip_attn_output
# Process attention outputs for the `encoder_hidden_states`.
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return encoder_hidden_states, hidden_states
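# Illustrative sketch (not part of the diff): a single block step. `temb` is the pooled time
# embedding of shape [batch, dim]; `image_rotary_emb` (omitted here) would be the (cos, sin)
# pair covering the concatenated text+image sequence. Sizes are example values only.
#
#   block = BriaFiboTransformerBlock(dim=3072, num_attention_heads=24, attention_head_dim=128)
#   img, txt, temb = torch.randn(1, 4096, 3072), torch.randn(1, 512, 3072), torch.randn(1, 3072)
#   txt, img = block(hidden_states=img, encoder_hidden_states=txt, temb=temb)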
class FIBOTimesteps(nn.Module):
def __init__(
self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
):
@@ -209,7 +420,7 @@ class TimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, time_theta):
super().__init__()
self.time_proj = Timesteps(
self.time_proj = FIBOTimesteps(
num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
@@ -258,7 +469,7 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
self.out_channels = in_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
self.pos_embed = FIBOEmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
@@ -270,7 +481,7 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
BriaFiboTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,

View File

@@ -1,47 +0,0 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["fibo_vlm_prompt_to_json"] = ["BriaFiboVLMPromptToJson"]
_import_structure["gemini_prompt_to_json"] = ["BriaFiboGeminiPromptToJson"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .fibo_vlm_prompt_to_json import BriaFiboVLMPromptToJson
from .gemini_prompt_to_json import BriaFiboGeminiPromptToJson
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -1,373 +0,0 @@
import json
import math
import textwrap
from typing import Any, Dict, Iterable, List, Optional
import torch
from boltons.iterutils import remap
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, Qwen3VLForConditionalGeneration
from .. import ComponentSpec, InputParam, ModularPipelineBlocks, OutputParam, PipelineState
def parse_aesthetic_score(record: dict) -> str:
ae = record["aesthetic_score"]
if ae < 5.5:
return "very low"
elif ae < 6:
return "low"
elif ae < 7:
return "medium"
elif ae < 7.6:
return "high"
else:
return "very high"
def parse_pickascore(record: dict) -> str:
ps = record["pickascore"]
if ps < 0.78:
return "very low"
elif ps < 0.82:
return "low"
elif ps < 0.87:
return "medium"
elif ps < 0.91:
return "high"
else:
return "very high"
def prepare_clean_caption(record: dict) -> str:
def keep(p, k, v):
is_none = v is None
is_empty_string = isinstance(v, str) and v == ""
is_empty_dict = isinstance(v, dict) and not v
is_empty_list = isinstance(v, list) and not v
is_nan = isinstance(v, float) and math.isnan(v)
if is_none or is_empty_string or is_empty_list or is_empty_dict or is_nan:
return False
return True
try:
scores = {}
if "pickascore" in record:
scores["preference_score"] = parse_pickascore(record)
if "aesthetic_score" in record:
scores["aesthetic_score"] = parse_aesthetic_score(record)
clean_caption_dict = remap(record, visit=keep)
# Set aesthetics scores
if "aesthetics" not in clean_caption_dict:
if len(scores) > 0:
clean_caption_dict["aesthetics"] = scores
else:
clean_caption_dict["aesthetics"].update(scores)
        # Dump the clean structured caption as a minimal JSON string (i.e., no newline/whitespace separators)
clean_caption_str = json.dumps(clean_caption_dict)
return clean_caption_str
except Exception as ex:
print("Error: ", ex)
raise ex
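# Illustrative sketch (not part of the diff): `prepare_clean_caption` drops empty/NaN fields
# and adds a bucketed "aesthetics" entry derived from the raw scores (which remain in place).
# Values below are arbitrary examples.
#
#   record = {"short_description": "a red bicycle", "objects": [], "pickascore": 0.93,
#             "aesthetic_score": 7.8, "style_medium": None}
#   prepare_clean_caption(record)
#   # -> '{"short_description": "a red bicycle", "pickascore": 0.93, "aesthetic_score": 7.8,
#   #      "aesthetics": {"preference_score": "very high", "aesthetic_score": "very high"}}'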
def _collect_images(messages: Iterable[Dict[str, Any]]) -> List[Image.Image]:
images: List[Image.Image] = []
for message in messages:
content = message.get("content", [])
if not isinstance(content, list):
continue
for item in content:
if not isinstance(item, dict):
continue
if item.get("type") != "image":
continue
image_value = item.get("image")
if isinstance(image_value, Image.Image):
images.append(image_value)
else:
raise ValueError("Expected PIL.Image for image content in messages.")
return images
def _strip_stop_sequences(text: str, stop_sequences: Optional[List[str]]) -> str:
if not stop_sequences:
return text.strip()
cleaned = text
for stop in stop_sequences:
if not stop:
continue
index = cleaned.find(stop)
if index >= 0:
cleaned = cleaned[:index]
return cleaned.strip()
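# Illustrative sketch (not part of the diff): decoded text is truncated at the first
# occurrence of any stop string, then stripped.
#
#   _strip_stop_sequences('{"a": 1}<|im_end|> trailing', ["<|im_end|>"])
#   # -> '{"a": 1}'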
class TransformersEngine(torch.nn.Module):
"""Inference wrapper using Hugging Face transformers."""
def __init__(
self,
model: str,
*,
processor_kwargs: Optional[Dict[str, Any]] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
super(TransformersEngine, self).__init__()
default_processor_kwargs: Dict[str, Any] = {
"min_pixels": 256 * 28 * 28,
"max_pixels": 1024 * 28 * 28,
}
processor_kwargs = {**default_processor_kwargs, **(processor_kwargs or {})}
model_kwargs = model_kwargs or {}
self.processor = AutoProcessor.from_pretrained(model, **processor_kwargs)
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
model,
dtype=torch.bfloat16,
**model_kwargs,
)
self.model.eval()
tokenizer_obj = self.processor.tokenizer
if tokenizer_obj.pad_token_id is None:
tokenizer_obj.pad_token = tokenizer_obj.eos_token
self._pad_token_id = tokenizer_obj.pad_token_id
eos_token_id = tokenizer_obj.eos_token_id
if isinstance(eos_token_id, list) and eos_token_id:
self._eos_token_id = eos_token_id
elif eos_token_id is not None:
self._eos_token_id = [eos_token_id]
else:
raise ValueError("Tokenizer must define an EOS token for generation.")
def dtype(self) -> torch.dtype:
return self.model.dtype
def device(self) -> torch.device:
return self.model.device
def _to_model_device(self, value: Any) -> Any:
if not isinstance(value, torch.Tensor):
return value
target_device = getattr(self.model, "device", None)
if target_device is None or target_device.type == "meta":
return value
if value.device == target_device:
return value
return value.to(target_device)
def generate(
self,
messages: List[Dict[str, Any]],
top_p: float,
temperature: float,
max_tokens: int,
stop: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
tokenizer = self.processor.tokenizer
prompt_text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
processor_inputs: Dict[str, Any] = {
"text": [prompt_text],
"padding": True,
"return_tensors": "pt",
}
images = _collect_images(messages)
if images:
processor_inputs["images"] = images
inputs = self.processor(**processor_inputs)
inputs = {key: self._to_model_device(value) for key, value in inputs.items()}
generation_kwargs: Dict[str, Any] = {
"max_new_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
"do_sample": temperature > 0,
"eos_token_id": self._eos_token_id,
"pad_token_id": self._pad_token_id,
}
with torch.inference_mode():
generated_ids = self.model.generate(**inputs, **generation_kwargs)
input_ids = inputs.get("input_ids")
if input_ids is None:
raise RuntimeError("Processor did not return input_ids; cannot compute new tokens.")
new_token_ids = generated_ids[:, input_ids.shape[-1] :]
decoded = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)
if not decoded:
return ""
text = decoded[0]
stripped_text = _strip_stop_sequences(text, stop)
json_prompt = json.loads(stripped_text)
return json_prompt
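# Illustrative sketch (not part of the diff): constructing the engine and generating a
# structured prompt. The checkpoint name is a placeholder, not a reference to a published model.
#
#   engine = TransformersEngine("<qwen3-vl-checkpoint>")
#   messages = [{"role": "user", "content": [{"type": "text", "text": "<generate>\nA red bicycle"}]}]
#   json_prompt = engine.generate(
#       messages, top_p=0.9, temperature=0.2, max_tokens=4096, stop=["<|im_end|>"]
#   )
#   # `json_prompt` is the dict produced by `json.loads` above.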
def generate_json_prompt(
vlm_processor: AutoModelForCausalLM,
top_p: float,
temperature: float,
max_tokens: int,
stop: List[str],
image: Optional[Image.Image] = None,
prompt: Optional[str] = None,
structured_prompt: Optional[str] = None,
):
if image is None and structured_prompt is None:
# only got prompt
task = "generate"
editing_instructions = None
elif image is None and structured_prompt is not None and prompt is not None:
# got structured prompt and prompt
task = "refine"
editing_instructions = prompt
elif image is not None and structured_prompt is None and prompt is not None:
# got image and prompt
task = "refine"
editing_instructions = prompt
elif image is not None and structured_prompt is None and prompt is None:
# only got image
task = "inspire"
editing_instructions = None
else:
raise ValueError("Invalid input")
messages = build_messages(
task,
image=image,
prompt=prompt,
structured_prompt=structured_prompt,
editing_instructions=editing_instructions,
)
generated_prompt = vlm_processor.generate(
messages=messages, top_p=top_p, temperature=temperature, max_tokens=max_tokens, stop=stop
)
cleaned_json_data = prepare_clean_caption(generated_prompt)
return cleaned_json_data
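# Summary of the task selection above (illustrative comment, not part of the diff):
#   prompt only                  -> "generate"  (create a new structured prompt)
#   prompt + structured_prompt   -> "refine"    (edit an existing structured prompt)
#   prompt + image               -> "refine"    (edit guided by an image)
#   image only                   -> "inspire"   (derive a structured prompt from the image)
# Any other combination raises ValueError("Invalid input").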
def build_messages(
task: str,
*,
image: Optional[Image.Image] = None,
refine_image: Optional[Image.Image] = None,
prompt: Optional[str] = None,
structured_prompt: Optional[str] = None,
editing_instructions: Optional[str] = None,
) -> List[Dict[str, Any]]:
user_content: List[Dict[str, Any]] = []
if task == "inspire":
user_content.append({"type": "image", "image": image})
user_content.append({"type": "text", "text": "<inspire>"})
elif task == "generate":
text_value = (prompt or "").strip()
formatted = f"<generate>\n{text_value}"
user_content.append({"type": "text", "text": formatted})
else: # refine
if refine_image is None:
base_prompt = (structured_prompt or "").strip()
edits = (editing_instructions or "").strip()
formatted = textwrap.dedent(f"""<refine> Input: {base_prompt} Editing instructions: {edits}""").strip()
user_content.append({"type": "text", "text": formatted})
else:
user_content.append({"type": "image", "image": refine_image})
edits = (editing_instructions or "").strip()
formatted = textwrap.dedent(f"""<refine> Editing instructions: {edits}""").strip()
user_content.append({"type": "text", "text": formatted})
messages: List[Dict[str, Any]] = []
messages.append({"role": "user", "content": user_content})
return messages
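# Illustrative sketch (not part of the diff): the message structure produced for the
# "generate" task.
#
#   build_messages("generate", prompt="A red bicycle leaning on a wall")
#   # -> [{"role": "user",
#   #      "content": [{"type": "text", "text": "<generate>\nA red bicycle leaning on a wall"}]}]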
class BriaFiboVLMPromptToJson(ModularPipelineBlocks):
model_name = "BriaFibo"
def __init__(self, model_id):
super().__init__()
self.engine = TransformersEngine(model_id)
self.engine.model.to("cuda")
@property
def expected_components(self) -> List[ComponentSpec]:
return []
@property
def inputs(self) -> List[InputParam]:
prompt_input = InputParam(
"prompt",
type_hint=str,
required=False,
description="Prompt to use",
)
image_input = InputParam(
name="image", type_hint=Image.Image, required=False, description="image for inspiration mode"
)
json_prompt_input = InputParam(
name="json_prompt", type_hint=str, required=False, description="JSON prompt to use"
)
sampling_top_p_input = InputParam(
name="sampling_top_p", type_hint=float, required=False, description="Sampling top p", default=0.9
)
sampling_temperature_input = InputParam(
name="sampling_temperature",
type_hint=float,
required=False,
description="Sampling temperature",
default=0.2,
)
sampling_max_tokens_input = InputParam(
name="sampling_max_tokens", type_hint=int, required=False, description="Sampling max tokens", default=4096
)
return [
prompt_input,
image_input,
json_prompt_input,
sampling_top_p_input,
sampling_temperature_input,
sampling_max_tokens_input,
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return []
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"json_prompt",
type_hint=str,
description="JSON prompt by the VLM",
)
]
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
prompt = block_state.prompt
image = block_state.image
json_prompt = block_state.json_prompt
block_state.json_prompt = generate_json_prompt(
vlm_processor=self.engine,
image=image,
prompt=prompt,
structured_prompt=json_prompt,
top_p=block_state.sampling_top_p,
temperature=block_state.sampling_temperature,
max_tokens=block_state.sampling_max_tokens,
stop=["<|im_end|>", "<|end_of_text|>"],
)
self.set_block_state(state, block_state)
return components, state

View File

@@ -1,804 +0,0 @@
import io
import json
import math
import os
from functools import cache
from typing import List, Optional, Tuple
from boltons.iterutils import remap
from google import genai
from PIL import Image
from pydantic import BaseModel, Field
from ...modular_pipelines import InputParam, ModularPipelineBlocks, OutputParam, PipelineState
class ObjectDescription(BaseModel):
description: str = Field(..., description="Short description of the object.")
location: str = Field(..., description="E.g., 'center', 'top-left', 'bottom-right foreground'.")
relationship: str = Field(
..., description="Describe the relationship between the object and the other objects in the image."
)
relative_size: Optional[str] = Field(None, description="E.g., 'small', 'medium', 'large within frame'.")
shape_and_color: Optional[str] = Field(None, description="Describe the basic shape and dominant color.")
texture: Optional[str] = Field(None, description="E.g., 'smooth', 'rough', 'metallic', 'furry'.")
appearance_details: Optional[str] = Field(None, description="Any other notable visual details.")
# If cluster of object
number_of_objects: Optional[int] = Field(None, description="The number of objects in the cluster.")
# Human-specific fields
pose: Optional[str] = Field(None, description="Describe the body position.")
expression: Optional[str] = Field(None, description="Describe facial expression.")
clothing: Optional[str] = Field(None, description="Describe attire.")
action: Optional[str] = Field(None, description="Describe the action of the human.")
gender: Optional[str] = Field(None, description="Describe the gender of the human.")
skin_tone_and_texture: Optional[str] = Field(None, description="Describe the skin tone and texture.")
orientation: Optional[str] = Field(None, description="Describe the orientation of the human.")
class LightingDetails(BaseModel):
conditions: str = Field(
..., description="E.g., 'bright daylight', 'dim indoor', 'studio lighting', 'golden hour'."
)
direction: str = Field(..., description="E.g., 'front-lit', 'backlit', 'side-lit from left'.")
shadows: Optional[str] = Field(None, description="Describe the presence of shadows.")
class AestheticsDetails(BaseModel):
composition: str = Field(..., description="E.g., 'rule of thirds', 'symmetrical', 'centered', 'leading lines'.")
color_scheme: str = Field(
..., description="E.g., 'monochromatic blue', 'warm complementary colors', 'high contrast'."
)
mood_atmosphere: str = Field(..., description="E.g., 'serene', 'energetic', 'mysterious', 'joyful'.")
class PhotographicCharacteristicsDetails(BaseModel):
depth_of_field: str = Field(..., description="E.g., 'shallow', 'deep', 'bokeh background'.")
focus: str = Field(..., description="E.g., 'sharp focus on subject', 'soft focus', 'motion blur'.")
camera_angle: str = Field(..., description="E.g., 'eye-level', 'low angle', 'high angle', 'dutch angle'.")
lens_focal_length: str = Field(..., description="E.g., 'wide-angle', 'telephoto', 'macro', 'fisheye'.")
class TextRender(BaseModel):
text: str = Field(..., description="The text content.")
location: str = Field(..., description="E.g., 'center', 'top-left', 'bottom-right foreground'.")
size: str = Field(..., description="E.g., 'small', 'medium', 'large within frame'.")
color: str = Field(..., description="E.g., 'red', 'blue', 'green'.")
font: str = Field(..., description="E.g., 'realistic', 'cartoonish', 'minimalist'.")
appearance_details: Optional[str] = Field(None, description="Any other notable visual details.")
class ImageAnalysis(BaseModel):
short_description: str = Field(..., description="A concise summary of the image content, 200 words maximum.")
objects: List[ObjectDescription] = Field(..., description="List of prominent foreground/midground objects.")
background_setting: str = Field(
...,
description="Describe the overall environment, setting, or background, including any notable background elements.",
)
lighting: LightingDetails = Field(..., description="Details about the lighting.")
aesthetics: AestheticsDetails = Field(..., description="Details about the image aesthetics.")
photographic_characteristics: Optional[PhotographicCharacteristicsDetails] = Field(
None, description="Details about photographic characteristics."
)
style_medium: Optional[str] = Field(None, description="Identify the artistic style or medium.")
text_render: Optional[List[TextRender]] = Field(None, description="List of text renders in the image.")
context: str = Field(..., description="Provide any additional context that helps understand the image better.")
artistic_style: Optional[str] = Field(
None, description="describe specific artistic characteristics, 3 words maximum."
)
def get_gemini_output_schema() -> dict:
return {
"properties": {
"short_description": {"type": "STRING"},
"objects": {
"items": {
"properties": {
"description": {"type": "STRING"},
"location": {"type": "STRING"},
"relationship": {"type": "STRING"},
"relative_size": {"type": "STRING"},
"shape_and_color": {"type": "STRING"},
"texture": {"nullable": True, "type": "STRING"},
"appearance_details": {"nullable": True, "type": "STRING"},
"number_of_objects": {"nullable": True, "type": "INTEGER"},
"pose": {"nullable": True, "type": "STRING"},
"expression": {"nullable": True, "type": "STRING"},
"clothing": {"nullable": True, "type": "STRING"},
"action": {"nullable": True, "type": "STRING"},
"gender": {"nullable": True, "type": "STRING"},
"skin_tone_and_texture": {"nullable": True, "type": "STRING"},
"orientation": {"nullable": True, "type": "STRING"},
},
"required": [
"description",
"location",
"relationship",
"relative_size",
"shape_and_color",
"texture",
"appearance_details",
"number_of_objects",
"pose",
"expression",
"clothing",
"action",
"gender",
"skin_tone_and_texture",
"orientation",
],
"type": "OBJECT",
},
"type": "ARRAY",
},
"background_setting": {"type": "STRING"},
"lighting": {
"properties": {
"conditions": {"type": "STRING"},
"direction": {"type": "STRING"},
"shadows": {"nullable": True, "type": "STRING"},
},
"required": ["conditions", "direction", "shadows"],
"type": "OBJECT",
},
"aesthetics": {
"properties": {
"composition": {"type": "STRING"},
"color_scheme": {"type": "STRING"},
"mood_atmosphere": {"type": "STRING"},
},
"required": ["composition", "color_scheme", "mood_atmosphere"],
"type": "OBJECT",
},
"photographic_characteristics": {
"nullable": True,
"properties": {
"depth_of_field": {"type": "STRING"},
"focus": {"type": "STRING"},
"camera_angle": {"type": "STRING"},
"lens_focal_length": {"type": "STRING"},
},
"required": [
"depth_of_field",
"focus",
"camera_angle",
"lens_focal_length",
],
"type": "OBJECT",
},
"style_medium": {"type": "STRING"},
"text_render": {
"items": {
"properties": {
"text": {"type": "STRING"},
"location": {"type": "STRING"},
"size": {"type": "STRING"},
"color": {"type": "STRING"},
"font": {"type": "STRING"},
"appearance_details": {"nullable": True, "type": "STRING"},
},
"required": [
"text",
"location",
"size",
"color",
"font",
"appearance_details",
],
"type": "OBJECT",
},
"type": "ARRAY",
},
"context": {"type": "STRING"},
"artistic_style": {"type": "STRING"},
},
"required": [
"short_description",
"objects",
"background_setting",
"lighting",
"aesthetics",
"photographic_characteristics",
"style_medium",
"text_render",
"context",
"artistic_style",
],
"type": "OBJECT",
}
json_schema_full = """1. `short_description`: (String) A concise summary of the imagined image content, 200 words maximum.
2. `objects`: (Array of Objects) List a maximum of 5 prominent objects. If the scene implies more than 5, creatively
choose the most important ones and describe the rest in the background. For each object, include:
* `description`: (String) A detailed description of the imagined object, 100 words maximum.
* `location`: (String) E.g., "center", "top-left", "bottom-right foreground".
* `relative_size`: (String) E.g., "small", "medium", "large within frame". (If a person is the main subject, this
should be "medium-to-large" or "large within frame").
* `shape_and_color`: (String) Describe the basic shape and dominant color.
* `texture`: (String) E.g., "smooth", "rough", "metallic", "furry".
* `appearance_details`: (String) Any other notable visual details.
* `relationship`: (String) Describe the relationship between the object and the other objects in the image.
* `orientation`: (String) Describe the orientation or positioning of the object, e.g., "upright", "tilted 45
degrees", "horizontal", "vertical", "facing left", "facing right", "upside down", "lying on its side".
* If the object is a human or a human-like object, include the following:
* `pose`: (String) Describe the body position.
* `expression`: (String) Describe facial expression and emotion. E.g., "winking", "joyful", "serious",
"surprised", "calm".
* `clothing`: (String) Describe attire.
* `action`: (String) Describe the action of the human.
* `gender`: (String) Describe the gender of the human.
* `skin_tone_and_texture`: (String) Describe the skin tone and texture.
* If the object is a cluster of objects, include the following:
* `number_of_objects`: (Integer) The number of objects in the cluster.
3. `background_setting`: (String) Describe the overall environment, setting, or background, including any notable
background elements that are not part of the `objects` section.
4. `lighting`: (Object)
* `conditions`: (String) E.g., "bright daylight", "dim indoor", "studio lighting", "golden hour".
* `direction`: (String) E.g., "front-lit", "backlit", "side-lit from left".
* `shadows`: (String) Describe the presence and quality of shadows, e.g., "long, soft shadows", "sharp, defined
shadows", "minimal shadows".
5. `aesthetics`: (Object)
* `composition`: (String) E.g., "rule of thirds", "symmetrical", "centered", "leading lines". If people are the
main subject, specify the shot type, e.g., "medium shot", "close-up", "portrait composition".
* `color_scheme`: (String) E.g., "monochromatic blue", "warm complementary colors", "high contrast".
* `mood_atmosphere`: (String) E.g., "serene", "energetic", "mysterious", "joyful".
6. `photographic_characteristics`: (Object)
* `depth_of_field`: (String) E.g., "shallow", "deep", "bokeh background".
* `focus`: (String) E.g., "sharp focus on subject", "soft focus", "motion blur".
* `camera_angle`: (String) E.g., "eye-level", "low angle", "high angle", "dutch angle".
* `lens_focal_length`: (String) E.g., "wide-angle", "telephoto", "macro", "fisheye". (If the main subject is a
person, prefer "standard lens (e.g., 35mm-50mm)" or "portrait lens (e.g., 50mm-85mm)" to ensure they are framed
more closely. Avoid "wide-angle" for people unless specified).
7. `style_medium`: (String) Identify the artistic style or medium based on the user's prompt or creative
interpretation (e.g., "photograph", "oil painting", "watercolor", "3D render", "digital illustration", "pencil
sketch").
8. `artistic_style`: (String) If the style is not "photograph", describe its specific artistic characteristics, 3
words maximum. (e.g., "impressionistic, vibrant, textured" for an oil painting).
9. `context`: (String) Provide a general description of the type of image this would be. For example: "This is a
concept for a high-fashion editorial photograph intended for a magazine spread," or "This describes a piece of
concept art for a fantasy video game."
10. `text_render`: (Array of Objects) By default, this array should be empty (`[]`). Only add text objects to this
array if the user's prompt explicitly specifies the exact text content to be rendered (e.g., user asks for "a
poster with the title 'Cosmic Dream'"). Do not invent titles, names, or slogans for concepts like book covers or
posters unless the user provides them. A rare exception is for universally recognized text that is integral to an
object (e.g., the word 'STOP' on a 'stop sign'). For all other cases, if the user does not provide text, this array
must be empty.
* `text`: (String) The exact text content provided by the user. NEVER use generic placeholders.
* `location`: (String) E.g., "center", "top-left", "bottom-right foreground".
* `size`: (String) E.g., "medium", "large", "large within frame".
* `color`: (String) E.g., "red", "blue", "green".
* `font`: (String) E.g., "realistic", "cartoonish", "minimalist", "serif typeface".
* `appearance_details`: (String) Any other notable visual details."""
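# Illustrative sketch (not part of the diff): an abridged example of the kind of structured
# prompt described by the schema above. All values are arbitrary examples.
#
#   example_structured_prompt = {
#       "short_description": "A red bicycle leaning against a brick wall at sunset.",
#       "objects": [{"description": "A red city bicycle", "location": "center",
#                    "relationship": "leaning against the wall",
#                    "relative_size": "large within frame",
#                    "shape_and_color": "angular frame, bright red"}],
#       "background_setting": "A weathered brick wall on a quiet street.",
#       "lighting": {"conditions": "golden hour", "direction": "side-lit from left",
#                    "shadows": "long, soft shadows"},
#       "aesthetics": {"composition": "rule of thirds",
#                      "color_scheme": "warm complementary colors",
#                      "mood_atmosphere": "serene"},
#       "photographic_characteristics": {"depth_of_field": "shallow",
#                                        "focus": "sharp focus on subject",
#                                        "camera_angle": "eye-level",
#                                        "lens_focal_length": "standard lens (50mm)"},
#       "style_medium": "photograph",
#       "text_render": [],
#       "context": "A lifestyle photograph for an urban cycling blog.",
#       "artistic_style": "realistic",
#   }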
@cache
def get_instructions(mode: str) -> Tuple[str, str]:
system_prompts = {}
system_prompts["Caption"] = """
You are a meticulous and perceptive Visual Art Director working for a leading Generative AI company. Your expertise
lies in analyzing images and extracting detailed, structured information. Your primary task is to analyze provided
images and generate a comprehensive JSON object describing them. Adhere strictly to the following structure and
guidelines: The output MUST be ONLY a valid JSON object. Do not include any text before or after the JSON object (e.g.,
no "Here is the JSON:", no explanations, no apologies). IMPORTANT: When describing human body parts, positions, or
actions, always describe them from the PERSON'S OWN PERSPECTIVE, not from the observer's viewpoint. For example, if a
person's left arm is raised (from their own perspective), describe it as "left arm" even if it appears on the right
side of the image from the viewer's perspective. The JSON object must contain the following keys precisely:
1. `short_description`: (String) A concise summary of the image content, 200 words maximum.
2. `objects`: (Array of Objects) List a maximum of 5 prominent objects; if there are more than 5, list them in the
background. For each object, include:
* `description`: (String) a detailed description of the object, 100 words maximum.
* `location`: (String) E.g., "center", "top-left", "bottom-right foreground".
* `relative_size`: (String) E.g., "small", "medium", "large within frame".
* `shape_and_color`: (String) Describe the basic shape and dominant color.
* `texture`: (String) E.g., "smooth", "rough", "metallic", "furry".
* `appearance_details`: (String) Any other notable visual details.
* `relationship`: (String) Describe the relationship between the object and the other objects in the image.
* `orientation`: (String) Describe the orientation or positioning of the object, e.g., "upright", "tilted 45
degrees", "horizontal", "vertical", "facing left", "facing right", "upside down", "lying on its side".
if the object is a human or a human-like object, include the following:
* `pose`: (String) Describe the body position.
* `expression`: (String) Describe facial expression and emotion. E.g., "winking", "joyful", "serious",
"surprised", "calm".
* `clothing`: (String) Describe attire.
* `action`: (String) Describe the action of the human.
* `gender`: (String) Describe the gender of the human.
* `skin_tone_and_texture`: (String) Describe the skin tone and texture.
if the object is a cluster of objects, include the following:
* `number_of_objects`: (Integer) The number of objects in the cluster.
3. `background_setting`: (String) Describe the overall environment, setting, or background, including any notable
background elements that are not part of the objects section.
4. `lighting`: (Object)
* `conditions`: (String) E.g., "bright daylight", "dim indoor", "studio lighting", "golden hour".
* `direction`: (String) E.g., "front-lit", "backlit", "side-lit from left".
* `shadows`: (String) Describe the presence of shadows.
5. `aesthetics`: (Object)
* `composition`: (String) E.g., "rule of thirds", "symmetrical", "centered", "leading lines".
* `color_scheme`: (String) E.g., "monochromatic blue", "warm complementary colors", "high contrast".
* `mood_atmosphere`: (String) E.g., "serene", "energetic", "mysterious", "joyful".
6. `photographic_characteristics`: (Object)
* `depth_of_field`: (String) E.g., "shallow", "deep", "bokeh background".
* `focus`: (String) E.g., "sharp focus on subject", "soft focus", "motion blur".
* `camera_angle`: (String) E.g., "eye-level", "low angle", "high angle", "dutch angle".
* `lens_focal_length`: (String) E.g., "wide-angle", "telephoto", "macro", "fisheye".
7. `style_medium`: (String) Identify the artistic style or medium (e.g., "photograph", "oil painting", "watercolor",
"3D render", "digital illustration", "pencil sketch") If the style is not "photograph", but artistic, please
describe the specific artistic characteristics under 'artistic_style', 50 words maximum.
8. `artistic_style`: (String) describe specific artistic characteristics, 3 words maximum.
9. `context`: (String) Provide any additional context that helps understand the image better. This should include a
general description of the type of image (e.g., Fashion Photography, Product Shot, Magazine Cover, Nature
Photography, Art Piece, etc.), as well as any other relevant contextual information that situates the image within a
broader category or intended use. For example: "This is a high-fashion editorial photograph intended for a magazine
spread"
10. `text_render`: (Array of Objects) List of a maximum of 5 most prominent text renders in the image. For each text
render, include:
* `text`: (String) The text content.
* `location`: (String) E.g., "center", "top-left", "bottom-right foreground".
* `size`: (String) E.g., "small", "medium", "large within frame".
* `color`: (String) E.g., "red", "blue", "green".
* `font`: (String) E.g., "realistic", "cartoonish", "minimalist".
* `appearance_details`: (String) Any other notable visual details.
Ensure the information within the JSON is accurate, detailed where specified, and avoids redundancy between fields.
"""
system_prompts[
"Generate"
] = f"""You are a visionary and creative Visual Art Director at a leading Generative AI company.
Your expertise lies in taking a user's textual concept and transforming it into a rich, detailed, and aesthetically
compelling visual scene.
Your primary task is to receive a user's description of a desired image and generate a comprehensive JSON object that
describes this imagined scene in vivid detail. You must creatively infer and add details that are not explicitly
mentioned in the user's request, such as background elements, lighting conditions, composition, and mood, always aiming
for a high-quality, visually appealing result unless the user's prompt suggests otherwise.
Adhere strictly to the following structure and guidelines:
The output MUST be ONLY a valid JSON object. Do not include any text before or after the JSON object (e.g., no "Here is
the JSON:", no explanations, no apologies).
IMPORTANT: When describing human body parts, positions, or actions, always describe them from the PERSON'S OWN
PERSPECTIVE, not from the observer's viewpoint. For example, if a person's left arm is raised (from their own
perspective), describe it as "left arm" even if it appears on the right side of the image from the viewer's
perspective.
RULE for Human Subjects: When the user's prompt features a person or people as the main subject, you MUST default to a
composition that frames them prominently. Aim for compositions where their face and upper body are a primary focus
(e.g., 'medium shot', 'close-up'). Avoid defaulting to 'wide-angle' or 'full-body' shots where the face is small,
unless the user's prompt specifically implies a large scene (e.g., "a person standing on a mountain").
Unless the user's prompt explicitly requests a different style (e.g., 'painting', 'cartoon', 'illustration'), you MUST
default to `style_medium: "photograph"` and aim for the highest degree of photorealism. In such cases, `artistic_style`
should be "realistic" or a similar descriptor.
The JSON object must contain the following keys precisely:
{json_schema_full}
Ensure the information within the JSON is detailed, creative, internally consistent, and avoids redundancy between
fields."""
system_prompts[
"RefineA"
] = f"""You are a Meticulous Visual Editor and Senior Art Director at a leading Generative AI company.
Your expertise is in refining and modifying existing visual concepts based on precise feedback.
Your primary task is to receive an existing JSON object that describes a visual scene, along with a textual instruction
for how to change it. You must then generate a new, updated JSON object that perfectly incorporates the requested
changes.
Adhere strictly to the following structure and guidelines:
1. **Input:** You will receive two pieces of information: an existing JSON object and a textual instruction.
2. **Output:** Your output MUST be ONLY a single, valid JSON object in the specified schema. Do not include any text
before or after the JSON object.
3. **Modification Logic:**
* Carefully parse the user's textual instruction to understand the desired changes.
* Modify ONLY the fields in the JSON that are directly or logically affected by the instruction.
* All other fields not relevant to the change must be copied exactly from the original JSON. Do not alter or omit
them.
4. **Holistic Consistency (IMPORTANT):** Changes in one field must be logically reflected in others. For example:
* If the instruction is to "change the background to a snowy forest," you must update the `background_setting`
field, and also update the `short_description` to mention the new setting. The `mood_atmosphere` might also need
to change to "serene" or "wintry."
* If the instruction is to "add the text 'WINTER SALE' at the top," you must add a new entry to the `text_render`
array.
* If the instruction is to "make the person smile," you must update the `expression` field for that object and
potentially update the overall `mood_atmosphere`.
5. **Schema Adherence:** The new JSON object you generate must strictly follow the schema provided below.
The JSON object must contain the following keys precisely:
{json_schema_full}"""
system_prompts[
"RefineB"
] = f"""You are an advanced Multimodal Visual Specialist at a leading Generative AI company.
Your unique expertise is in analyzing and editing visual concepts by processing an image, its corresponding JSON
metadata, and textual feedback simultaneously.
Your primary task is to receive three inputs: an existing image, its descriptive JSON object, and a textual instruction
for a modification. You must use the image as the primary source of truth to understand the context of the requested
change and then generate a new, updated JSON object that accurately reflects that change.
Adhere strictly to the following structure and guidelines:
1. **Inputs:** You will receive an image, an existing JSON object, and a textual instruction.
2. **Visual Grounding (IMPORTANT):** The provided image is the ground truth. Use it to visually verify the contents of
the scene and to understand the context of the user's edit instruction. For example, if the instruction is "make
the car blue," visually locate the car in the image to inform your edits to the JSON.
3. **Output:** Your output MUST be ONLY a single, valid JSON object in the specified schema. Do not include any text
before or after the JSON object.
4. **Modification Logic:**
* Analyze the user's textual instruction in the context of what you see in the image.
* Modify ONLY the fields in the JSON that are directly or logically affected by the instruction.
* All other fields not relevant to the change must be copied exactly from the original JSON.
5. **Holistic Consistency:** Changes must be reflected logically across the JSON, consistent with a potential visual
change to the image. For instance, changing the lighting from 'daylight' to 'golden hour' should not only update
the `lighting` object but also the `mood_atmosphere`, `shadows`, and the `short_description`.
6. **Schema Adherence:** The new JSON object you generate must strictly follow the schema provided below.
The JSON object must contain the following keys precisely:
{json_schema_full}"""
system_prompts[
"InspireA"
] = f"""You are a highly skilled Creative Director for Visual Adaptation at a leading Generative AI company.
Your expertise lies in using an existing image as a visual reference to create entirely new scenes. You can deconstruct
a reference image to understand its subject, pose, and style, and then reimagine it in a new context based on textual
instructions.
Your primary task is to receive a reference image and a textual instruction. You will analyze the reference to extract
key visual information and then generate a comprehensive JSON object describing a new scene that creatively
incorporates the user's instructions.
Adhere strictly to the following structure and guidelines:
1. **Inputs:** You will receive a reference image and a textual instruction. You will NOT receive a starting JSON.
2. **Core Logic (Analyze and Synthesize):**
* **Analyze:** First, deeply analyze the provided reference image. Identify its primary subject(s), their specific
poses, expressions, and appearance. Also note the overall composition, lighting style, and artistic medium.
* **Synthesize:** Next, interpret the textual instruction to understand what elements to keep from the reference
and what to change. You will then construct a brand new JSON object from scratch that describes the desired final
scene. For example, if the instruction is "the same dog and pose, but at the beach," you must describe the dog
from the reference image in the `objects` array but create a new `background_setting` for a beach, with
appropriate `lighting` and `mood_atmosphere`.
3. **Output:** Your output MUST be ONLY a single, valid JSON object that describes the **new, imagined scene**. Do not
describe the original reference image.
4. **Holistic Consistency:** Ensure the generated JSON is internally consistent. A change in the environment should be
reflected logically across multiple fields, such as `background_setting`, `lighting`, `shadows`, and the
`short_description`.
5. **Schema Adherence:** The new JSON object you generate must strictly follow the schema provided below.
The JSON object must contain the following keys precisely:
{json_schema_full}"""
system_prompts["InspireB"] = system_prompts["Caption"]
final_prompts = {}
final_prompts["Generate"] = (
"Generate a detailed JSON object, adhering to the expected schema, for an imagined scene based on the following request: {user_prompt}."
)
final_prompts["RefineA"] = """
[EXISTING JSON]: {json_data}
[EDIT INSTRUCTIONS]: {user_prompt}
[TASK]: Generate the new, updated JSON object that incorporates the edit instructions. Follow all system rules
for modification, consistency, and formatting.
"""
final_prompts["RefineB"] = """
[EXISTING JSON]: {json_data}
[EDIT INSTRUCTIONS]: {user_prompt}
[TASK]: Analyze the provided image and its contextual JSON. Then, generate the new, updated JSON object that
incorporates the edit instructions. Follow all your system rules for visual analysis, modification, and
consistency.
"""
final_prompts["InspireA"] = """
[EDIT INSTRUCTIONS]: {user_prompt}
[TASK]: Use the provided image as a visual reference only. Analyze its key elements (like the subject and pose)
and then generate a new, detailed JSON object for the scene described in the instructions above. Do not
describe the reference image itself; describe the new scene. Follow all of your system rules.
"""
final_prompts["Caption"] = (
"Analyze the provided image and generate the detailed JSON object as specified in your instructions."
)
final_prompts["InspireB"] = final_prompts["Caption"]
return system_prompts.get(mode, ""), final_prompts.get(mode, "")
def keep(p, k, v):
is_none = v is None
is_empty_string = isinstance(v, str) and v == ""
is_empty_dict = isinstance(v, dict) and not v
is_empty_list = isinstance(v, list) and not v
is_nan = isinstance(v, float) and math.isnan(v)
if is_none or is_empty_string or is_empty_list or is_empty_dict:
return False
if is_nan:
return False
return True
def validate_json(json_data: dict) -> dict:
ia = ImageAnalysis.model_validate_json(json_data, strict=True)
return ia.model_dump(exclude_none=True)
def validate_structured_prompt_str(structured_prompt_str: str) -> str:
ia = ImageAnalysis.model_validate_json(structured_prompt_str, strict=True)
c = ia.model_dump(exclude_none=True)
return json.dumps(c)
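# Illustrative sketch (not part of the diff): validating a structured prompt string against
# the `ImageAnalysis` model above. Only required fields are shown; values are arbitrary.
#
#   minimal = json.dumps({
#       "short_description": "A red bicycle.",
#       "objects": [{"description": "A red bicycle", "location": "center",
#                    "relationship": "sole subject"}],
#       "background_setting": "Plain studio backdrop.",
#       "lighting": {"conditions": "studio lighting", "direction": "front-lit"},
#       "aesthetics": {"composition": "centered", "color_scheme": "high contrast",
#                      "mood_atmosphere": "calm"},
#       "context": "Product shot.",
#   })
#   validate_structured_prompt_str(minimal)  # returns the re-serialized JSON string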
def prepare_clean_caption(json_dump: dict) -> str:
    # filter empty values recursively (i.e. None, "", {}, [], float("nan"))
clean_caption_dict = remap(json_dump, visit=keep)
scores = {"preference_score": "very high", "aesthetic_score": "very high"}
# Set aesthetics scores
if "aesthetics" not in clean_caption_dict:
clean_caption_dict["aesthetics"] = scores
else:
clean_caption_dict["aesthetics"].update(scores)
    # Dump the clean structured caption as a minimal JSON string (i.e., no newline/whitespace separators)
clean_caption_str = json.dumps(clean_caption_dict)
return clean_caption_str
# Resize an input image to a target number of pixels (default 1,048,576, i.e. 1024x1024)
# while preserving its aspect ratio; the output width and height are rounded down to multiples of `granularity_val`.
def resize_image_by_num_pixels(
image: Image.Image, pixel_number: int = 1048576, granularity_val: int = 64, target_ratio: float = 0.0
) -> Image.Image:
if target_ratio != 0.0:
ratio = target_ratio
else:
ratio = image.size[0] / image.size[1]
width = int((pixel_number * ratio) ** 0.5)
width = width - (width % granularity_val)
height = int(pixel_number / width)
height = height - (height % granularity_val)
return image.resize((width, height))
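# Illustrative sketch: a 2000x1500 source at the default settings is brought close to one
# megapixel while keeping its aspect ratio, with both sides rounded down to multiples of 64.
#
#   from PIL import Image
#   img = Image.new("RGB", (2000, 1500))
#   resize_image_by_num_pixels(img).size  # -> (1152, 896)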
def infer_with_gemini(
client: genai.Client,
final_prompt: str,
system_prompt: str,
top_p: float,
temperature: float,
max_tokens: int,
seed: int = 42,
image: Optional[Image.Image] = None,
model: str = "gemini-2.5-flash",
) -> str:
"""
Calls Gemini API with the given prompt and returns the raw JSON response.
Args:
final_prompt: The text prompt to send to Gemini
system_prompt: The system instruction for Gemini
existing_image_path: Optional path to an image file to include
model: The Gemini model to use
Returns:
Raw JSON response text from Gemini
"""
parts = [{"text": final_prompt}]
if image:
# Save image into bytes
        image = image.convert("RGB")  # the model cannot produce RGBA, so alpha channels in the input have no effect
        max_pixels = 262144  # cap model inputs at roughly 512x512 pixels
        if image.size[0] * image.size[1] > max_pixels:
            image = resize_image_by_num_pixels(
                image, pixel_number=max_pixels, granularity_val=1, target_ratio=0.0
            )
buffer = io.BytesIO()
image.save(buffer, format="JPEG")
image_bytes = buffer.getvalue()
img_part = {
"inlineData": {
"data": image_bytes,
"mimeType": "image/jpeg",
}
}
parts.append(img_part)
contents = [{"role": "user", "parts": parts}]
generationConfig = {
"temperature": temperature,
"topP": top_p,
"maxOutputTokens": max_tokens,
"response_mime_type": "application/json",
"response_schema": get_gemini_output_schema(),
"system_instruction": system_prompt, # len 5900
"thinkingConfig": {"thinkingBudget": 0},
"seed": seed,
}
response = client.models.generate_content(
model=model,
contents=contents,
config=generationConfig,
)
    if response.candidates[0].finish_reason == "MAX_TOKENS":
        raise Exception("Gemini response was truncated (finish_reason == MAX_TOKENS); increase `max_tokens`.")
return response.candidates[0].content.parts[0].text
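# Illustrative usage sketch (requires a configured google-genai client and API access):
#
#   client = genai.Client()
#   raw_json = infer_with_gemini(
#       client=client,
#       final_prompt=final_prompt,    # user-turn text built by get_instructions(...).format(...)
#       system_prompt=system_prompt,  # matching system instruction
#       top_p=1.0,
#       temperature=0.2,
#       max_tokens=3000,
#   )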
def get_default_negative_prompt(existing_json: dict) -> str:
negative_prompt = ""
style_medium = existing_json.get("style_medium", "").lower()
if style_medium in ["photograph", "photography", "photo"]:
negative_prompt = """{'style_medium': 'digital illustration', 'artistic_style': 'non-realistic'}"""
return negative_prompt
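# Illustrative sketch: only photographic styles receive a default negative prompt, steering
# generation away from illustration-style outputs.
#
#   get_default_negative_prompt({"style_medium": "Photograph"})
#   # -> "{'style_medium': 'digital illustration', 'artistic_style': 'non-realistic'}"
#   get_default_negative_prompt({"style_medium": "oil painting"})
#   # -> ""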
def json_promptify(
client: genai.Client,
model_id: str,
top_p: float,
temperature: float,
max_tokens: int,
user_prompt: Optional[str] = None,
existing_json: Optional[str] = None,
image: Optional[Image.Image] = None,
seed: int = 42,
) -> str:
if existing_json:
# make sure aesthetic scores are not in the existing json (will be added later)
existing_json = json.loads(existing_json)
if "aesthetics" in existing_json:
existing_json["aesthetics"].pop("aesthetic_score", None)
existing_json["aesthetics"].pop("preference_score", None)
existing_json = json.dumps(existing_json)
if not user_prompt:
raise ValueError("user_prompt is required if existing_json is provided")
if image:
mode = "RefineB"
system_prompt, final_prompt = get_instructions(mode)
final_prompt = final_prompt.format(user_prompt=user_prompt, json_data=existing_json)
else:
mode = "RefineA"
system_prompt, final_prompt = get_instructions(mode)
final_prompt = final_prompt.format(user_prompt=user_prompt, json_data=existing_json)
elif image and user_prompt:
mode = "InspireA"
system_prompt, final_prompt = get_instructions(mode)
final_prompt = final_prompt.format(user_prompt=user_prompt)
elif image and not user_prompt:
mode = "Caption"
system_prompt, final_prompt = get_instructions(mode)
else:
mode = "Generate"
system_prompt, final_prompt = get_instructions(mode)
final_prompt = final_prompt.format(user_prompt=user_prompt)
json_data = infer_with_gemini(
client=client,
model=model_id,
final_prompt=final_prompt,
system_prompt=system_prompt,
seed=seed,
image=image,
top_p=top_p,
temperature=temperature,
max_tokens=max_tokens,
)
json_data = validate_json(json_data)
clean_caption = prepare_clean_caption(json_data)
return clean_caption
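# Illustrative usage sketch (requires a Gemini API key, e.g. via the GOOGLE_API_KEY
# environment variable); with no image and no existing JSON this falls into "Generate" mode.
#
#   client = genai.Client()
#   json_prompt = json_promptify(
#       client=client,
#       model_id="gemini-2.5-flash",
#       top_p=1.0,
#       temperature=0.2,
#       max_tokens=3000,
#       user_prompt="a beautiful dog",
#   )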
class BriaFiboGeminiPromptToJson(ModularPipelineBlocks):
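    """
    Modular pipeline block that turns a free-text prompt, an optional reference image, and/or an
    existing structured JSON prompt into a FIBO-ready JSON prompt by calling a Gemini VLM.
    Requires the `GOOGLE_API_KEY` environment variable to be set.
    """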
model_name = "BriaFibo"
def __init__(self, model_id="gemini-2.5-flash"):
super().__init__()
api_key = os.getenv("GOOGLE_API_KEY")
if api_key is None:
raise ValueError("Must provide an API key for Gemini through the `GOOGLE_API_KEY` env variable.")
self.model_id = model_id
@property
def expected_components(self):
return []
@property
def inputs(self) -> List[InputParam]:
task_input = InputParam("task", type_hint=str, required=False, description="VLM Task to execute")
prompt_input = InputParam(
"prompt",
type_hint=str,
required=False,
description="Prompt to use",
)
image_input = InputParam(
name="image", type_hint=Image.Image, required=False, description="image for inspiration mode"
)
json_prompt_input = InputParam(
name="json_prompt", type_hint=str, required=False, description="JSON prompt to use"
)
sampling_top_p_input = InputParam(
name="sampling_top_p", type_hint=float, required=False, description="Sampling top p", default=1.0
)
sampling_temperature_input = InputParam(
name="sampling_temperature",
type_hint=float,
required=False,
description="Sampling temperature",
default=0.2,
)
sampling_max_tokens_input = InputParam(
name="sampling_max_tokens", type_hint=int, required=False, description="Sampling max tokens", default=3000
)
return [
task_input,
prompt_input,
image_input,
json_prompt_input,
sampling_top_p_input,
sampling_temperature_input,
sampling_max_tokens_input,
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return []
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"json_prompt",
type_hint=str,
description="JSON prompt by the VLM",
)
]
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
prompt = block_state.prompt
image = block_state.image
json_prompt = block_state.json_prompt
client = genai.Client()
json_prompt = json_promptify(
client=client,
model_id=self.model_id,
top_p=block_state.sampling_top_p,
temperature=block_state.sampling_temperature,
max_tokens=block_state.sampling_max_tokens,
user_prompt=prompt,
existing_json=json_prompt,
image=image,
)
block_state.json_prompt = json_prompt
self.set_block_state(state, block_state)
return components, state
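# Illustrative usage sketch: in practice this block is consumed through the published modular
# pipeline rather than instantiated directly, e.g.:
#
#   from diffusers.modular_pipelines import ModularPipeline
#   vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)
#   output = vlm_pipe(prompt="a beautiful dog")
#   json_prompt = output.values["json_prompt"]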


@@ -1,3 +1,13 @@
# Copyright (c) Bria.ai. All rights reserved.
#
# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
#
# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
# indicate if changes were made, and do not use the material for commercial purposes.
#
# See the license for further details.
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
@@ -35,20 +45,47 @@ else:
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
...
Example:
```python
import torch
from diffusers import BriaFiboPipeline
from diffusers.modular_pipelines import ModularPipeline
torch.set_grad_enabled(False)
vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)
pipe = BriaFiboPipeline.from_pretrained(
"briaai/FIBO",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
)
pipe.enable_sequential_cpu_offload()
with torch.inference_mode():
# 1. Create a prompt to generate an initial image
output = vlm_pipe(prompt="a beautiful dog")
json_prompt_generate = output.values["json_prompt"]
# Generate the image from the structured json prompt
results_generate = pipe(prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=5)
results_generate.images[0].save("image_generate.png")
```
"""
class BriaFiboPipeline(DiffusionPipeline):
r"""
Args:
transformer ([`GaiaTransformer2DModel`]):
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`SmolLM3ForCausalLM`]):
transformer (`BriaFiboTransformer2DModel`):
The transformer model for 2D diffusion modeling.
scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`):
Scheduler to be used with `transformer` to denoise the encoded latents.
vae (`AutoencoderKLWan`):
Variational Auto-Encoder for encoding and decoding images to and from latent representations.
text_encoder (`SmolLM3ForCausalLM`):
Text encoder for processing input prompts.
tokenizer (`AutoTokenizer`):
Tokenizer used for processing the input text prompts for the text_encoder.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
@@ -166,7 +203,7 @@ class BriaFiboPipeline(DiffusionPipeline):
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
guidance_scale: float = 5,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -181,8 +218,8 @@ class BriaFiboPipeline(DiffusionPipeline):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
guidance_scale (`float`):
Guidance scale for classifier free guidance.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -224,7 +261,7 @@ class BriaFiboPipeline(DiffusionPipeline):
prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers]
if do_classifier_free_guidance:
if guidance_scale > 1:
if isinstance(negative_prompt, list) and negative_prompt[0] is None:
negative_prompt = ""
negative_prompt = negative_prompt or ""
@@ -302,9 +339,6 @@ class BriaFiboPipeline(DiffusionPipeline):
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def joint_attention_kwargs(self):
@@ -320,6 +354,7 @@ class BriaFiboPipeline(DiffusionPipeline):
@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline
batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor
@@ -366,6 +401,7 @@ class BriaFiboPipeline(DiffusionPipeline):
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
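        # Shape walk-through: with batch_size=1, num_channels_latents=16 and a 32x32 latent,
        # (1, 16, 32, 32) -> view (1, 16, 16, 2, 16, 2) -> permute (1, 16, 16, 16, 2, 2)
        # -> reshape (1, 256, 64), i.e. one packed token per 2x2 latent patch.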
@@ -410,34 +446,7 @@ class BriaFiboPipeline(DiffusionPipeline):
return latents, latent_image_ids
@staticmethod
def init_inference_scheduler(height, width, device, image_seq_len, num_inference_steps=1000, noise_scheduler=None):
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
assert height % 16 == 0 and width % 16 == 0
mu = calculate_shift(
image_seq_len,
noise_scheduler.config.base_image_seq_len,
noise_scheduler.config.max_image_seq_len,
noise_scheduler.config.base_shift,
noise_scheduler.config.max_shift,
)
# Init sigmas and timesteps according to shift size
# This changes the scheduler in-place according to the dynamic scheduling
timesteps, num_inference_steps = retrieve_timesteps(
noise_scheduler,
num_inference_steps=num_inference_steps,
device=device,
timesteps=None,
sigmas=sigmas,
mu=mu,
)
return noise_scheduler, timesteps, num_inference_steps, mu
@staticmethod
def create_attention_matrix(attention_mask):
def _prepare_attention_mask(attention_mask):
attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask)
# convert to 0 - keep, -inf ignore
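        # e.g. attention_mask = [[1, 1, 0]] -> attention_matrix = [[[1, 1, 0], [1, 1, 0], [0, 0, 0]]]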
@@ -583,7 +592,7 @@ class BriaFiboPipeline(DiffusionPipeline):
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guidance_scale=guidance_scale,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
@@ -593,7 +602,7 @@ class BriaFiboPipeline(DiffusionPipeline):
)
prompt_batch_size = prompt_embeds.shape[0]
if self.do_classifier_free_guidance:
if guidance_scale > 1:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_layers = [
torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers))
@@ -611,6 +620,7 @@ class BriaFiboPipeline(DiffusionPipeline):
prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers))
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
if do_patching:
num_channels_latents = int(num_channels_latents / 4)
@@ -630,11 +640,11 @@ class BriaFiboPipeline(DiffusionPipeline):
latent_attention_mask = torch.ones(
[latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device
)
if self.do_classifier_free_guidance:
if guidance_scale > 1:
latent_attention_mask = latent_attention_mask.repeat(2, 1)
attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1)
attention_mask = self.create_attention_matrix(attention_mask) # batch, seq => batch, seq, seq
attention_mask = self._prepare_attention_mask(attention_mask) # batch, seq => batch, seq, seq
attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting
if self._joint_attention_kwargs is None:
@@ -648,13 +658,25 @@ class BriaFiboPipeline(DiffusionPipeline):
else:
seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor)
self.noise_scheduler, timesteps, num_inference_steps, mu = self.init_inference_scheduler(
height=height,
width=width,
device=device,
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
mu = calculate_shift(
seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
# Init sigmas and timesteps according to shift size
# This changes the scheduler in-place according to the dynamic scheduling
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps=num_inference_steps,
noise_scheduler=self.scheduler,
image_seq_len=seq_len,
device=device,
timesteps=None,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -674,10 +696,7 @@ class BriaFiboPipeline(DiffusionPipeline):
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
if type(self.scheduler) != FlowMatchEulerDiscreteScheduler:
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).to(
@@ -697,7 +716,7 @@ class BriaFiboPipeline(DiffusionPipeline):
)[0]
# perform guidance
if self.do_classifier_free_guidance:
if guidance_scale > 1:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
@@ -736,7 +755,6 @@ class BriaFiboPipeline(DiffusionPipeline):
else:
latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor)
latents = latents.to(dtype=self.vae.dtype)
latents = latents.unsqueeze(dim=2)
latents = list(torch.unbind(latents, dim=0))
latents_device = latents[0].device
@@ -780,8 +798,8 @@ class BriaFiboPipeline(DiffusionPipeline):
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -818,9 +836,3 @@ class BriaFiboPipeline(DiffusionPipeline):
if max_sequence_length is not None and max_sequence_length > 3000:
raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}")
def to(self, *args, **kwargs):
DiffusionPipeline.to(self, *args, **kwargs)
        # We keep the VAE in float32, since the Wan 2.2 reference repo uses it this way
self.vae.to(dtype=torch.float32)
return self


@@ -20,7 +20,7 @@ import torch
from diffusers import BriaFiboTransformer2DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
@@ -84,48 +84,6 @@ class BriaFiboTransformerTests(ModelTesterMixin, unittest.TestCase):
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_deprecated_inputs_img_txt_ids_3d(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output_1 = model(**inputs_dict)[0]
# update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"
inputs_dict["txt_ids"] = text_ids_3d
inputs_dict["img_ids"] = image_ids_3d
with torch.no_grad():
output_2 = model(**inputs_dict)[0]
self.assertEqual(output_1.shape, output_2.shape)
self.assertTrue(
torch.allclose(output_1, output_2, atol=1e-5),
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
)
def test_gradient_checkpointing_is_applied(self):
expected_set = {"BriaFiboTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class BriaFiboTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = BriaFiboTransformer2DModel
def prepare_init_args_and_inputs_for_common(self):
return BriaFiboTransformerTests().prepare_init_args_and_inputs_for_common()
class BriaFiboTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = BriaFiboTransformer2DModel
def prepare_init_args_and_inputs_for_common(self):
return BriaFiboTransformerTests().prepare_init_args_and_inputs_for_common()


@@ -128,7 +128,7 @@ class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipe = pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32)]
height_width_pairs = [(32, 32), (64, 64), (32, 64)]
for height, width in height_width_pairs:
expected_height = height
expected_width = width
@@ -181,13 +181,18 @@ class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
)
def test_to_dtype(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
# def test_to_dtype(self):
# components = self.get_dummy_components()
# pipe = self.pipeline_class(**components)
# pipe.set_progress_bar_config(disable=None)
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
# model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
# self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
# pipe.to(dtype=torch.float16)
# model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
# self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
@unittest.skip("")
def test_save_load_dduf(self):
pass