From a617433aceb67d2bab269a7cfcb04f75d5443612 Mon Sep 17 00:00:00 2001 From: galbria Date: Mon, 27 Oct 2025 13:04:57 +0000 Subject: [PATCH] fix CR --- docs/source/en/_toctree.yml | 4 + .../transformers/transformer_bria_fibo.py | 349 ++++++-- .../modular_pipelines/bria_fibo/__init__.py | 47 - .../bria_fibo/fibo_vlm_prompt_to_json.py | 373 -------- .../bria_fibo/gemini_prompt_to_json.py | 804 ------------------ .../pipelines/bria_fibo/pipeline_bria_fibo.py | 144 ++-- .../test_models_transformer_bria_fibo.py | 44 +- .../bria_fibo/test_pipeline_bria_fibo.py | 19 +- 8 files changed, 375 insertions(+), 1409 deletions(-) delete mode 100644 src/diffusers/modular_pipelines/bria_fibo/__init__.py delete mode 100644 src/diffusers/modular_pipelines/bria_fibo/fibo_vlm_prompt_to_json.py delete mode 100644 src/diffusers/modular_pipelines/bria_fibo/gemini_prompt_to_json.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 540e99a2c6..a5f0efe02f 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -323,6 +323,8 @@ title: AllegroTransformer3DModel - local: api/models/aura_flow_transformer2d title: AuraFlowTransformer2DModel + - local: api/models/transformer_bria_fibo + title: BriaFiboTransformer2DModel - local: api/models/bria_transformer title: BriaTransformer2DModel - local: api/models/chroma_transformer @@ -469,6 +471,8 @@ title: BLIP-Diffusion - local: api/pipelines/bria_3_2 title: Bria 3.2 + - local: api/pipelines/bria_fibo + title: Bria Fibo - local: api/pipelines/chroma title: Chroma - local: api/pipelines/cogview3 diff --git a/src/diffusers/models/transformers/transformer_bria_fibo.py b/src/diffusers/models/transformers/transformer_bria_fibo.py index 9521b7f3dd..e1bfde9555 100644 --- a/src/diffusers/models/transformers/transformer_bria_fibo.py +++ b/src/diffusers/models/transformers/transformer_bria_fibo.py @@ -1,18 +1,26 @@ +# Copyright (c) Bria.ai. All rights reserved. +# +# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0). +# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/ +# +# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit, +# indicate if changes were made, and do not use the material for commercial purposes. +# +# See the license for further details. 
+import inspect from typing import Any, Dict, List, Optional, Tuple, Union -import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...models.attention_processor import Attention -from ...models.embeddings import TimestepEmbedding, get_timestep_embedding +from ...models.embeddings import TimestepEmbedding, apply_rotary_emb, get_1d_rotary_pos_embed, get_timestep_embedding from ...models.modeling_outputs import Transformer2DModelOutput from ...models.modeling_utils import ModelMixin -from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZeroSingle from ...models.transformers.transformer_bria import BriaAttnProcessor -from ...models.transformers.transformer_flux import FluxTransformerBlock from ...utils import ( USE_PEFT_BACKEND, logging, @@ -20,76 +28,193 @@ from ...utils import ( unscale_lora_layers, ) from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def get_1d_rotary_pos_embed( - dim: int, - pos: Union[np.ndarray, int], - theta: float = 10000.0, - use_real=False, - linear_factor=1.0, - ntk_factor=1.0, - repeat_interleave_real=True, - freqs_dtype=torch.float32, # torch.float32, torch.float64 -): - """ - Precompute the frequency tensor for complex exponentials (cis) with given dimensions. This function calculates a - frequency tensor with complex exponentials using the given dimension 'dim' and the end index 'end'. The 'theta' - parameter scales the frequencies. The returned tensor contains complex values in complex64 data type. +def _get_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) - Args: - dim (`int`): Dimension of the frequency tensor. - pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar - theta (`float`, *optional*, defaults to 10000.0): - Scaling factor for frequency computation. Defaults to 10000.0. - use_real (`bool`, *optional*): - If True, return real part and imaginary part separately. Otherwise, return complex numbers. - linear_factor (`float`, *optional*, defaults to 1.0): - Scaling factor for the context extrapolation. Defaults to 1.0. - ntk_factor (`float`, *optional*, defaults to 1.0): - Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. - repeat_interleave_real (`bool`, *optional*, defaults to `True`): - If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. - Otherwise, they are concateanted with themselves. - freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`): - the dtype of the frequency tensor. - Returns: - `torch.Tensor`: Precomputed frequency tensor with complex exponentials. 
[S, D/2] - """ - assert dim % 2 == 0 + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) - if isinstance(pos, int): - pos = torch.arange(pos) - if isinstance(pos, np.ndarray): - pos = torch.from_numpy(pos) # type: ignore # [S] - - theta = theta * ntk_factor - freqs = ( - 1.0 - / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) - / linear_factor - ) # [D/2] - freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] - if use_real and repeat_interleave_real: - # flux, hunyuan-dit, cogvideox - freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] - freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] - return freqs_cos, freqs_sin - elif use_real: - # stable audio, allegro - freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] - freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] - return freqs_cos, freqs_sin - else: - # lumina - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] - return freqs_cis + return query, key, value, encoder_query, encoder_key, encoder_value -class EmbedND(torch.nn.Module): +def _get_fused_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None): + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + + encoder_query = encoder_key = encoder_value = (None,) + if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): + encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_qkv_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None): + if attn.fused_projections: + return _get_fused_projections(attn, hidden_states, encoder_hidden_states) + return _get_projections(attn, hidden_states, encoder_hidden_states) + + +class BriaFiboAttnProcessor: + # Copied from diffusers.models.transformers.transformer_flux.FluxAttnProcessor + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") + + def __call__( + self, + attn: "BriaFiboAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class BriaFiboAttention(torch.nn.Module, AttentionModuleMixin): + # Copied from diffusers.models.transformers.transformer_flux.FluxAttention + _default_processor_cls = BriaFiboAttnProcessor + _available_processors = [ + BriaFiboAttnProcessor, + ] + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + context_pre_only: Optional[bool] = None, + pre_only: bool = False, + elementwise_affine: bool = True, + processor=None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.heads = out_dim // dim_head if out_dim is not None else heads + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + + if not 
self.pre_only: + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if added_kv_proj_dim is not None: + self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) + self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) + self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + + +class FIBOEmbedND(torch.nn.Module): # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 def __init__(self, theta: int, axes_dim: List[int]): super().__init__() @@ -182,7 +307,93 @@ class TextProjection(nn.Module): return hidden_states -class Timesteps(nn.Module): +@maybe_allow_in_graph +class BriaFiboTransformerBlock(nn.Module): + # Copied from diffusers.models.transformers.transformer_flux.FluxTransformerBlock + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 + ): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim) + self.norm1_context = AdaLayerNormZero(dim) + + self.attn = BriaFiboAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=BriaFiboAttnProcessor(), + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + joint_attention_kwargs = joint_attention_kwargs or {} 
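        # Editorial comment (descriptive only, not introduced by the patch author): the two
        # streams are modulated separately by norm1 / norm1_context above, but attention is
        # joint. BriaFiboAttnProcessor concatenates the projected text (encoder) and image
        # tokens along the sequence dimension, applies the rotary embedding to the combined
        # sequence, and splits the result back into the image and text outputs consumed below.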
+ + # Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class FIBOTimesteps(nn.Module): def __init__( self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000 ): @@ -209,7 +420,7 @@ class TimestepProjEmbeddings(nn.Module): def __init__(self, embedding_dim, time_theta): super().__init__() - self.time_proj = Timesteps( + self.time_proj = FIBOTimesteps( num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta ) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) @@ -258,7 +469,7 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From self.out_channels = in_channels self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim - self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope) + self.pos_embed = FIBOEmbedND(theta=rope_theta, axes_dim=axes_dims_rope) self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta) @@ -270,7 +481,7 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From self.transformer_blocks = nn.ModuleList( [ - FluxTransformerBlock( + BriaFiboTransformerBlock( dim=self.inner_dim, num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.config.attention_head_dim, diff --git a/src/diffusers/modular_pipelines/bria_fibo/__init__.py b/src/diffusers/modular_pipelines/bria_fibo/__init__.py deleted file mode 100644 index 770cb9391a..0000000000 --- a/src/diffusers/modular_pipelines/bria_fibo/__init__.py +++ /dev/null @@ -1,47 +0,0 @@ -from typing import TYPE_CHECKING - -from ...utils import ( - DIFFUSERS_SLOW_IMPORT, - OptionalDependencyNotAvailable, - _LazyModule, - get_objects_from_module, - is_torch_available, - is_transformers_available, -) - - -_dummy_objects = {} -_import_structure = {} - -try: - if not (is_transformers_available() and is_torch_available()): - 
raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils import dummy_torch_and_transformers_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) -else: - _import_structure["fibo_vlm_prompt_to_json"] = ["BriaFiboVLMPromptToJson"] - _import_structure["gemini_prompt_to_json"] = ["BriaFiboGeminiPromptToJson"] - -if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - try: - if not (is_transformers_available() and is_torch_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 - else: - from .fibo_vlm_prompt_to_json import BriaFiboVLMPromptToJson - from .gemini_prompt_to_json import BriaFiboGeminiPromptToJson -else: - import sys - - sys.modules[__name__] = _LazyModule( - __name__, - globals()["__file__"], - _import_structure, - module_spec=__spec__, - ) - - for name, value in _dummy_objects.items(): - setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/bria_fibo/fibo_vlm_prompt_to_json.py b/src/diffusers/modular_pipelines/bria_fibo/fibo_vlm_prompt_to_json.py deleted file mode 100644 index 689e4ae59a..0000000000 --- a/src/diffusers/modular_pipelines/bria_fibo/fibo_vlm_prompt_to_json.py +++ /dev/null @@ -1,373 +0,0 @@ -import json -import math -import textwrap -from typing import Any, Dict, Iterable, List, Optional - -import torch -from boltons.iterutils import remap -from PIL import Image -from transformers import AutoModelForCausalLM, AutoProcessor, Qwen3VLForConditionalGeneration - -from .. import ComponentSpec, InputParam, ModularPipelineBlocks, OutputParam, PipelineState - - -def parse_aesthetic_score(record: dict) -> str: - ae = record["aesthetic_score"] - if ae < 5.5: - return "very low" - elif ae < 6: - return "low" - elif ae < 7: - return "medium" - elif ae < 7.6: - return "high" - else: - return "very high" - - -def parse_pickascore(record: dict) -> str: - ps = record["pickascore"] - if ps < 0.78: - return "very low" - elif ps < 0.82: - return "low" - elif ps < 0.87: - return "medium" - elif ps < 0.91: - return "high" - else: - return "very high" - - -def prepare_clean_caption(record: dict) -> str: - def keep(p, k, v): - is_none = v is None - is_empty_string = isinstance(v, str) and v == "" - is_empty_dict = isinstance(v, dict) and not v - is_empty_list = isinstance(v, list) and not v - is_nan = isinstance(v, float) and math.isnan(v) - if is_none or is_empty_string or is_empty_list or is_empty_dict or is_nan: - return False - return True - - try: - scores = {} - if "pickascore" in record: - scores["preference_score"] = parse_pickascore(record) - if "aesthetic_score" in record: - scores["aesthetic_score"] = parse_aesthetic_score(record) - - clean_caption_dict = remap(record, visit=keep) - - # Set aesthetics scores - if "aesthetics" not in clean_caption_dict: - if len(scores) > 0: - clean_caption_dict["aesthetics"] = scores - else: - clean_caption_dict["aesthetics"].update(scores) - - # Dumps clean structured caption as minimal json string (i.e. 
no newlines\whitespaces seps) - clean_caption_str = json.dumps(clean_caption_dict) - return clean_caption_str - except Exception as ex: - print("Error: ", ex) - raise ex - - -def _collect_images(messages: Iterable[Dict[str, Any]]) -> List[Image.Image]: - images: List[Image.Image] = [] - for message in messages: - content = message.get("content", []) - if not isinstance(content, list): - continue - for item in content: - if not isinstance(item, dict): - continue - if item.get("type") != "image": - continue - image_value = item.get("image") - if isinstance(image_value, Image.Image): - images.append(image_value) - else: - raise ValueError("Expected PIL.Image for image content in messages.") - return images - - -def _strip_stop_sequences(text: str, stop_sequences: Optional[List[str]]) -> str: - if not stop_sequences: - return text.strip() - cleaned = text - for stop in stop_sequences: - if not stop: - continue - index = cleaned.find(stop) - if index >= 0: - cleaned = cleaned[:index] - return cleaned.strip() - - -class TransformersEngine(torch.nn.Module): - """Inference wrapper using Hugging Face transformers.""" - - def __init__( - self, - model: str, - *, - processor_kwargs: Optional[Dict[str, Any]] = None, - model_kwargs: Optional[Dict[str, Any]] = None, - ) -> None: - super(TransformersEngine, self).__init__() - default_processor_kwargs: Dict[str, Any] = { - "min_pixels": 256 * 28 * 28, - "max_pixels": 1024 * 28 * 28, - } - processor_kwargs = {**default_processor_kwargs, **(processor_kwargs or {})} - model_kwargs = model_kwargs or {} - - self.processor = AutoProcessor.from_pretrained(model, **processor_kwargs) - - self.model = Qwen3VLForConditionalGeneration.from_pretrained( - model, - dtype=torch.bfloat16, - **model_kwargs, - ) - self.model.eval() - - tokenizer_obj = self.processor.tokenizer - if tokenizer_obj.pad_token_id is None: - tokenizer_obj.pad_token = tokenizer_obj.eos_token - self._pad_token_id = tokenizer_obj.pad_token_id - eos_token_id = tokenizer_obj.eos_token_id - if isinstance(eos_token_id, list) and eos_token_id: - self._eos_token_id = eos_token_id - elif eos_token_id is not None: - self._eos_token_id = [eos_token_id] - else: - raise ValueError("Tokenizer must define an EOS token for generation.") - - def dtype(self) -> torch.dtype: - return self.model.dtype - - def device(self) -> torch.device: - return self.model.device - - def _to_model_device(self, value: Any) -> Any: - if not isinstance(value, torch.Tensor): - return value - target_device = getattr(self.model, "device", None) - if target_device is None or target_device.type == "meta": - return value - if value.device == target_device: - return value - return value.to(target_device) - - def generate( - self, - messages: List[Dict[str, Any]], - top_p: float, - temperature: float, - max_tokens: int, - stop: Optional[List[str]] = None, - ) -> str: - tokenizer = self.processor.tokenizer - prompt_text = tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True, - ) - processor_inputs: Dict[str, Any] = { - "text": [prompt_text], - "padding": True, - "return_tensors": "pt", - } - images = _collect_images(messages) - if images: - processor_inputs["images"] = images - inputs = self.processor(**processor_inputs) - inputs = {key: self._to_model_device(value) for key, value in inputs.items()} - - generation_kwargs: Dict[str, Any] = { - "max_new_tokens": max_tokens, - "temperature": temperature, - "top_p": top_p, - "do_sample": temperature > 0, - "eos_token_id": self._eos_token_id, - "pad_token_id": 
self._pad_token_id, - } - - with torch.inference_mode(): - generated_ids = self.model.generate(**inputs, **generation_kwargs) - - input_ids = inputs.get("input_ids") - if input_ids is None: - raise RuntimeError("Processor did not return input_ids; cannot compute new tokens.") - new_token_ids = generated_ids[:, input_ids.shape[-1] :] - decoded = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True) - if not decoded: - return "" - text = decoded[0] - stripped_text = _strip_stop_sequences(text, stop) - json_prompt = json.loads(stripped_text) - return json_prompt - - -def generate_json_prompt( - vlm_processor: AutoModelForCausalLM, - top_p: float, - temperature: float, - max_tokens: int, - stop: List[str], - image: Optional[Image.Image] = None, - prompt: Optional[str] = None, - structured_prompt: Optional[str] = None, -): - if image is None and structured_prompt is None: - # only got prompt - task = "generate" - editing_instructions = None - elif image is None and structured_prompt is not None and prompt is not None: - # got structured prompt and prompt - task = "refine" - editing_instructions = prompt - elif image is not None and structured_prompt is None and prompt is not None: - # got image and prompt - task = "refine" - editing_instructions = prompt - elif image is not None and structured_prompt is None and prompt is None: - # only got image - task = "inspire" - editing_instructions = None - else: - raise ValueError("Invalid input") - - messages = build_messages( - task, - image=image, - prompt=prompt, - structured_prompt=structured_prompt, - editing_instructions=editing_instructions, - ) - - generated_prompt = vlm_processor.generate( - messages=messages, top_p=top_p, temperature=temperature, max_tokens=max_tokens, stop=stop - ) - cleaned_json_data = prepare_clean_caption(generated_prompt) - return cleaned_json_data - - -def build_messages( - task: str, - *, - image: Optional[Image.Image] = None, - refine_image: Optional[Image.Image] = None, - prompt: Optional[str] = None, - structured_prompt: Optional[str] = None, - editing_instructions: Optional[str] = None, -) -> List[Dict[str, Any]]: - user_content: List[Dict[str, Any]] = [] - - if task == "inspire": - user_content.append({"type": "image", "image": image}) - user_content.append({"type": "text", "text": ""}) - elif task == "generate": - text_value = (prompt or "").strip() - formatted = f"\n{text_value}" - user_content.append({"type": "text", "text": formatted}) - else: # refine - if refine_image is None: - base_prompt = (structured_prompt or "").strip() - edits = (editing_instructions or "").strip() - formatted = textwrap.dedent(f""" Input: {base_prompt} Editing instructions: {edits}""").strip() - user_content.append({"type": "text", "text": formatted}) - else: - user_content.append({"type": "image", "image": refine_image}) - edits = (editing_instructions or "").strip() - formatted = textwrap.dedent(f""" Editing instructions: {edits}""").strip() - user_content.append({"type": "text", "text": formatted}) - - messages: List[Dict[str, Any]] = [] - messages.append({"role": "user", "content": user_content}) - return messages - - -class BriaFiboVLMPromptToJson(ModularPipelineBlocks): - model_name = "BriaFibo" - - def __init__(self, model_id): - super().__init__() - self.engine = TransformersEngine(model_id) - self.engine.model.to("cuda") - - @property - def expected_components(self) -> List[ComponentSpec]: - return [] - - @property - def inputs(self) -> List[InputParam]: - prompt_input = InputParam( - "prompt", - type_hint=str, - 
required=False, - description="Prompt to use", - ) - image_input = InputParam( - name="image", type_hint=Image.Image, required=False, description="image for inspiration mode" - ) - json_prompt_input = InputParam( - name="json_prompt", type_hint=str, required=False, description="JSON prompt to use" - ) - sampling_top_p_input = InputParam( - name="sampling_top_p", type_hint=float, required=False, description="Sampling top p", default=0.9 - ) - sampling_temperature_input = InputParam( - name="sampling_temperature", - type_hint=float, - required=False, - description="Sampling temperature", - default=0.2, - ) - sampling_max_tokens_input = InputParam( - name="sampling_max_tokens", type_hint=int, required=False, description="Sampling max tokens", default=4096 - ) - return [ - prompt_input, - image_input, - json_prompt_input, - sampling_top_p_input, - sampling_temperature_input, - sampling_max_tokens_input, - ] - - @property - def intermediate_inputs(self) -> List[InputParam]: - return [] - - @property - def intermediate_outputs(self) -> List[OutputParam]: - return [ - OutputParam( - "json_prompt", - type_hint=str, - description="JSON prompt by the VLM", - ) - ] - - def __call__(self, components, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - prompt = block_state.prompt - image = block_state.image - json_prompt = block_state.json_prompt - block_state.json_prompt = generate_json_prompt( - vlm_processor=self.engine, - image=image, - prompt=prompt, - structured_prompt=json_prompt, - top_p=block_state.sampling_top_p, - temperature=block_state.sampling_temperature, - max_tokens=block_state.sampling_max_tokens, - stop=["<|im_end|>", "<|end_of_text|>"], - ) - self.set_block_state(state, block_state) - - return components, state diff --git a/src/diffusers/modular_pipelines/bria_fibo/gemini_prompt_to_json.py b/src/diffusers/modular_pipelines/bria_fibo/gemini_prompt_to_json.py deleted file mode 100644 index dc69044483..0000000000 --- a/src/diffusers/modular_pipelines/bria_fibo/gemini_prompt_to_json.py +++ /dev/null @@ -1,804 +0,0 @@ -import io -import json -import math -import os -from functools import cache -from typing import List, Optional, Tuple - -from boltons.iterutils import remap -from google import genai -from PIL import Image -from pydantic import BaseModel, Field - -from ...modular_pipelines import InputParam, ModularPipelineBlocks, OutputParam, PipelineState - - -class ObjectDescription(BaseModel): - description: str = Field(..., description="Short description of the object.") - location: str = Field(..., description="E.g., 'center', 'top-left', 'bottom-right foreground'.") - relationship: str = Field( - ..., description="Describe the relationship between the object and the other objects in the image." 
- ) - relative_size: Optional[str] = Field(None, description="E.g., 'small', 'medium', 'large within frame'.") - shape_and_color: Optional[str] = Field(None, description="Describe the basic shape and dominant color.") - texture: Optional[str] = Field(None, description="E.g., 'smooth', 'rough', 'metallic', 'furry'.") - appearance_details: Optional[str] = Field(None, description="Any other notable visual details.") - # If cluster of object - number_of_objects: Optional[int] = Field(None, description="The number of objects in the cluster.") - # Human-specific fields - pose: Optional[str] = Field(None, description="Describe the body position.") - expression: Optional[str] = Field(None, description="Describe facial expression.") - clothing: Optional[str] = Field(None, description="Describe attire.") - action: Optional[str] = Field(None, description="Describe the action of the human.") - gender: Optional[str] = Field(None, description="Describe the gender of the human.") - skin_tone_and_texture: Optional[str] = Field(None, description="Describe the skin tone and texture.") - orientation: Optional[str] = Field(None, description="Describe the orientation of the human.") - - -class LightingDetails(BaseModel): - conditions: str = Field( - ..., description="E.g., 'bright daylight', 'dim indoor', 'studio lighting', 'golden hour'." - ) - direction: str = Field(..., description="E.g., 'front-lit', 'backlit', 'side-lit from left'.") - shadows: Optional[str] = Field(None, description="Describe the presence of shadows.") - - -class AestheticsDetails(BaseModel): - composition: str = Field(..., description="E.g., 'rule of thirds', 'symmetrical', 'centered', 'leading lines'.") - color_scheme: str = Field( - ..., description="E.g., 'monochromatic blue', 'warm complementary colors', 'high contrast'." 
- ) - mood_atmosphere: str = Field(..., description="E.g., 'serene', 'energetic', 'mysterious', 'joyful'.") - - -class PhotographicCharacteristicsDetails(BaseModel): - depth_of_field: str = Field(..., description="E.g., 'shallow', 'deep', 'bokeh background'.") - focus: str = Field(..., description="E.g., 'sharp focus on subject', 'soft focus', 'motion blur'.") - camera_angle: str = Field(..., description="E.g., 'eye-level', 'low angle', 'high angle', 'dutch angle'.") - lens_focal_length: str = Field(..., description="E.g., 'wide-angle', 'telephoto', 'macro', 'fisheye'.") - - -class TextRender(BaseModel): - text: str = Field(..., description="The text content.") - location: str = Field(..., description="E.g., 'center', 'top-left', 'bottom-right foreground'.") - size: str = Field(..., description="E.g., 'small', 'medium', 'large within frame'.") - color: str = Field(..., description="E.g., 'red', 'blue', 'green'.") - font: str = Field(..., description="E.g., 'realistic', 'cartoonish', 'minimalist'.") - appearance_details: Optional[str] = Field(None, description="Any other notable visual details.") - - -class ImageAnalysis(BaseModel): - short_description: str = Field(..., description="A concise summary of the image content, 200 words maximum.") - objects: List[ObjectDescription] = Field(..., description="List of prominent foreground/midground objects.") - background_setting: str = Field( - ..., - description="Describe the overall environment, setting, or background, including any notable background elements.", - ) - lighting: LightingDetails = Field(..., description="Details about the lighting.") - aesthetics: AestheticsDetails = Field(..., description="Details about the image aesthetics.") - photographic_characteristics: Optional[PhotographicCharacteristicsDetails] = Field( - None, description="Details about photographic characteristics." - ) - style_medium: Optional[str] = Field(None, description="Identify the artistic style or medium.") - text_render: Optional[List[TextRender]] = Field(None, description="List of text renders in the image.") - context: str = Field(..., description="Provide any additional context that helps understand the image better.") - artistic_style: Optional[str] = Field( - None, description="describe specific artistic characteristics, 3 words maximum." 
- ) - - -def get_gemini_output_schema() -> dict: - return { - "properties": { - "short_description": {"type": "STRING"}, - "objects": { - "items": { - "properties": { - "description": {"type": "STRING"}, - "location": {"type": "STRING"}, - "relationship": {"type": "STRING"}, - "relative_size": {"type": "STRING"}, - "shape_and_color": {"type": "STRING"}, - "texture": {"nullable": True, "type": "STRING"}, - "appearance_details": {"nullable": True, "type": "STRING"}, - "number_of_objects": {"nullable": True, "type": "INTEGER"}, - "pose": {"nullable": True, "type": "STRING"}, - "expression": {"nullable": True, "type": "STRING"}, - "clothing": {"nullable": True, "type": "STRING"}, - "action": {"nullable": True, "type": "STRING"}, - "gender": {"nullable": True, "type": "STRING"}, - "skin_tone_and_texture": {"nullable": True, "type": "STRING"}, - "orientation": {"nullable": True, "type": "STRING"}, - }, - "required": [ - "description", - "location", - "relationship", - "relative_size", - "shape_and_color", - "texture", - "appearance_details", - "number_of_objects", - "pose", - "expression", - "clothing", - "action", - "gender", - "skin_tone_and_texture", - "orientation", - ], - "type": "OBJECT", - }, - "type": "ARRAY", - }, - "background_setting": {"type": "STRING"}, - "lighting": { - "properties": { - "conditions": {"type": "STRING"}, - "direction": {"type": "STRING"}, - "shadows": {"nullable": True, "type": "STRING"}, - }, - "required": ["conditions", "direction", "shadows"], - "type": "OBJECT", - }, - "aesthetics": { - "properties": { - "composition": {"type": "STRING"}, - "color_scheme": {"type": "STRING"}, - "mood_atmosphere": {"type": "STRING"}, - }, - "required": ["composition", "color_scheme", "mood_atmosphere"], - "type": "OBJECT", - }, - "photographic_characteristics": { - "nullable": True, - "properties": { - "depth_of_field": {"type": "STRING"}, - "focus": {"type": "STRING"}, - "camera_angle": {"type": "STRING"}, - "lens_focal_length": {"type": "STRING"}, - }, - "required": [ - "depth_of_field", - "focus", - "camera_angle", - "lens_focal_length", - ], - "type": "OBJECT", - }, - "style_medium": {"type": "STRING"}, - "text_render": { - "items": { - "properties": { - "text": {"type": "STRING"}, - "location": {"type": "STRING"}, - "size": {"type": "STRING"}, - "color": {"type": "STRING"}, - "font": {"type": "STRING"}, - "appearance_details": {"nullable": True, "type": "STRING"}, - }, - "required": [ - "text", - "location", - "size", - "color", - "font", - "appearance_details", - ], - "type": "OBJECT", - }, - "type": "ARRAY", - }, - "context": {"type": "STRING"}, - "artistic_style": {"type": "STRING"}, - }, - "required": [ - "short_description", - "objects", - "background_setting", - "lighting", - "aesthetics", - "photographic_characteristics", - "style_medium", - "text_render", - "context", - "artistic_style", - ], - "type": "OBJECT", - } - - -json_schema_full = """1. `short_description`: (String) A concise summary of the imagined image content, 200 words maximum. -2. `objects`: (Array of Objects) List a maximum of 5 prominent objects. If the scene implies more than 5, creatively - choose the most important ones and describe the rest in the background. For each object, include: - * `description`: (String) A detailed description of the imagined object, 100 words maximum. - * `location`: (String) E.g., "center", "top-left", "bottom-right foreground". - * `relative_size`: (String) E.g., "small", "medium", "large within frame". 
(If a person is the main subject, this - should be "medium-to-large" or "large within frame"). - * `shape_and_color`: (String) Describe the basic shape and dominant color. - * `texture`: (String) E.g., "smooth", "rough", "metallic", "furry". - * `appearance_details`: (String) Any other notable visual details. - * `relationship`: (String) Describe the relationship between the object and the other objects in the image. - * `orientation`: (String) Describe the orientation or positioning of the object, e.g., "upright", "tilted 45 - degrees", "horizontal", "vertical", "facing left", "facing right", "upside down", "lying on its side". - * If the object is a human or a human-like object, include the following: - * `pose`: (String) Describe the body position. - * `expression`: (String) Describe facial expression and emotion. E.g., "winking", "joyful", "serious", - "surprised", "calm". - * `clothing`: (String) Describe attire. - * `action`: (String) Describe the action of the human. - * `gender`: (String) Describe the gender of the human. - * `skin_tone_and_texture`: (String) Describe the skin tone and texture. - * If the object is a cluster of objects, include the following: - * `number_of_objects`: (Integer) The number of objects in the cluster. -3. `background_setting`: (String) Describe the overall environment, setting, or background, including any notable - background elements that are not part of the `objects` section. -4. `lighting`: (Object) - * `conditions`: (String) E.g., "bright daylight", "dim indoor", "studio lighting", "golden hour". - * `direction`: (String) E.g., "front-lit", "backlit", "side-lit from left". - * `shadows`: (String) Describe the presence and quality of shadows, e.g., "long, soft shadows", "sharp, defined - shadows", "minimal shadows". -5. `aesthetics`: (Object) - * `composition`: (String) E.g., "rule of thirds", "symmetrical", "centered", "leading lines". If people are the - main subject, specify the shot type, e.g., "medium shot", "close-up", "portrait composition". - * `color_scheme`: (String) E.g., "monochromatic blue", "warm complementary colors", "high contrast". - * `mood_atmosphere`: (String) E.g., "serene", "energetic", "mysterious", "joyful". -6. `photographic_characteristics`: (Object) - * `depth_of_field`: (String) E.g., "shallow", "deep", "bokeh background". - * `focus`: (String) E.g., "sharp focus on subject", "soft focus", "motion blur". - * `camera_angle`: (String) E.g., "eye-level", "low angle", "high angle", "dutch angle". - * `lens_focal_length`: (String) E.g., "wide-angle", "telephoto", "macro", "fisheye". (If the main subject is a - person, prefer "standard lens (e.g., 35mm-50mm)" or "portrait lens (e.g., 50mm-85mm)" to ensure they are framed - more closely. Avoid "wide-angle" for people unless specified). -7. `style_medium`: (String) Identify the artistic style or medium based on the user's prompt or creative - interpretation (e.g., "photograph", "oil painting", "watercolor", "3D render", "digital illustration", "pencil - sketch"). -8. `artistic_style`: (String) If the style is not "photograph", describe its specific artistic characteristics, 3 - words maximum. (e.g., "impressionistic, vibrant, textured" for an oil painting). -9. `context`: (String) Provide a general description of the type of image this would be. For example: "This is a - concept for a high-fashion editorial photograph intended for a magazine spread," or "This describes a piece of - concept art for a fantasy video game." -10. 
`text_render`: (Array of Objects) By default, this array should be empty (`[]`). Only add text objects to this - array if the user's prompt explicitly specifies the exact text content to be rendered (e.g., user asks for "a - poster with the title 'Cosmic Dream'"). Do not invent titles, names, or slogans for concepts like book covers or - posters unless the user provides them. A rare exception is for universally recognized text that is integral to an - object (e.g., the word 'STOP' on a 'stop sign'). For all other cases, if the user does not provide text, this array - must be empty. - * `text`: (String) The exact text content provided by the user. NEVER use generic placeholders. - * `location`: (String) E.g., "center", "top-left", "bottom-right foreground". - * `size`: (String) E.g., "medium", "large", "large within frame". - * `color`: (String) E.g., "red", "blue", "green". - * `font`: (String) E.g., "realistic", "cartoonish", "minimalist", "serif typeface". - * `appearance_details`: (String) Any other notable visual details.""" - - -@cache -def get_instructions(mode: str) -> Tuple[str, str]: - system_prompts = {} - - system_prompts["Caption"] = """ -You are a meticulous and perceptive Visual Art Director working for a leading Generative AI company. Your expertise -lies in analyzing images and extracting detailed, structured information. Your primary task is to analyze provided -images and generate a comprehensive JSON object describing them. Adhere strictly to the following structure and -guidelines: The output MUST be ONLY a valid JSON object. Do not include any text before or after the JSON object (e.g., -no "Here is the JSON:", no explanations, no apologies). IMPORTANT: When describing human body parts, positions, or -actions, always describe them from the PERSON'S OWN PERSPECTIVE, not from the observer's viewpoint. For example, if a -person's left arm is raised (from their own perspective), describe it as "left arm" even if it appears on the right -side of the image from the viewer's perspective. The JSON object must contain the following keys precisely: -1. `short_description`: (String) A concise summary of the image content, 200 words maximum. -2. `objects`: (Array of Objects) List a maximum of 5 prominent objects if there are more than 5, list them in the - background. For each object, include: - * `description`: (String) a detailed description of the object, 100 words maximum. - * `location`: (String) E.g., "center", "top-left", "bottom-right foreground". - * `relative_size`: (String) E.g., "small", "medium", "large within frame". - * `shape_and_color`: (String) Describe the basic shape and dominant color. - * `texture`: (String) E.g., "smooth", "rough", "metallic", "furry". - * `appearance_details`: (String) Any other notable visual details. - * `relationship`: (String) Describe the relationship between the object and the other objects in the image. - * `orientation`: (String) Describe the orientation or positioning of the object, e.g., "upright", "tilted 45 - degrees", "horizontal", "vertical", "facing left", "facing right", "upside down", "lying on its side". - if the object is a human or a human-like object, include the following: - * `pose`: (String) Describe the body position. - * `expression`: (String) Describe facial expression and emotion. E.g., "winking", "joyful", "serious", - "surprised", "calm". - * `clothing`: (String) Describe attire. - * `action`: (String) Describe the action of the human. - * `gender`: (String) Describe the gender of the human. 
- * `skin_tone_and_texture`: (String) Describe the skin tone and texture. - if the object is a cluster of objects, include the following: - * `number_of_objects`: (Integer) The number of objects in the cluster. -3. `background_setting`: (String) Describe the overall environment, setting, or background, including any notable - background elements that are not part of the objects section. -4. `lighting`: (Object) - * `conditions`: (String) E.g., "bright daylight", "dim indoor", "studio lighting", "golden hour". - * `direction`: (String) E.g., "front-lit", "backlit", "side-lit from left". - * `shadows`: (String) Describe the presence of shadows. -5. `aesthetics`: (Object) - * `composition`: (String) E.g., "rule of thirds", "symmetrical", "centered", "leading lines". - * `color_scheme`: (String) E.g., "monochromatic blue", "warm complementary colors", "high contrast". - * `mood_atmosphere`: (String) E.g., "serene", "energetic", "mysterious", "joyful". -6. `photographic_characteristics`: (Object) - * `depth_of_field`: (String) E.g., "shallow", "deep", "bokeh background". - * `focus`: (String) E.g., "sharp focus on subject", "soft focus", "motion blur". - * `camera_angle`: (String) E.g., "eye-level", "low angle", "high angle", "dutch angle". - * `lens_focal_length`: (String) E.g., "wide-angle", "telephoto", "macro", "fisheye". -7. `style_medium`: (String) Identify the artistic style or medium (e.g., "photograph", "oil painting", "watercolor", - "3D render", "digital illustration", "pencil sketch") If the style is not "photograph", but artistic, please - describe the specific artistic characteristics under 'artistic_style', 50 words maximum. -8. `artistic_style`: (String) describe specific artistic characteristics, 3 words maximum. -9. `context`: (String) Provide any additional context that helps understand the image better. This should include a - general description of the type of image (e.g., Fashion Photography, Product Shot, Magazine Cover, Nature - Photography, Art Piece, etc.), as well as any other relevant contextual information that situates the image within a - broader category or intended use. For example: "This is a high-fashion editorial photograph intended for a magazine - spread" -10. `text_render`: (Array of Objects) List of a maximum of 5 most prominent text renders in the image. For each text - render, include: - * `text`: (String) The text content. - * `location`: (String) E.g., "center", "top-left", "bottom-right foreground". - * `size`: (String) E.g., "small", "medium", "large within frame". - * `color`: (String) E.g., "red", "blue", "green". - * `font`: (String) E.g., "realistic", "cartoonish", "minimalist". - * `appearance_details`: (String) Any other notable visual details. -Ensure the information within the JSON is accurate, detailed where specified, and avoids redundancy between fields. -""" - - system_prompts[ - "Generate" - ] = f"""You are a visionary and creative Visual Art Director at a leading Generative AI company. - -Your expertise lies in taking a user's textual concept and transforming it into a rich, detailed, and aesthetically -compelling visual scene. - -Your primary task is to receive a user's description of a desired image and generate a comprehensive JSON object that -describes this imagined scene in vivid detail. 
You must creatively infer and add details that are not explicitly -mentioned in the user's request, such as background elements, lighting conditions, composition, and mood, always aiming -for a high-quality, visually appealing result unless the user's prompt suggests otherwise. - -Adhere strictly to the following structure and guidelines: - -The output MUST be ONLY a valid JSON object. Do not include any text before or after the JSON object (e.g., no "Here is -the JSON:", no explanations, no apologies). - -IMPORTANT: When describing human body parts, positions, or actions, always describe them from the PERSON'S OWN -PERSPECTIVE, not from the observer's viewpoint. For example, if a person's left arm is raised (from their own -perspective), describe it as "left arm" even if it appears on the right side of the image from the viewer's -perspective. - -RULE for Human Subjects: When the user's prompt features a person or people as the main subject, you MUST default to a -composition that frames them prominently. Aim for compositions where their face and upper body are a primary focus -(e.g., 'medium shot', 'close-up'). Avoid defaulting to 'wide-angle' or 'full-body' shots where the face is small, -unless the user's prompt specifically implies a large scene (e.g., "a person standing on a mountain"). - -Unless the user's prompt explicitly requests a different style (e.g., 'painting', 'cartoon', 'illustration'), you MUST -default to `style_medium: "photograph"` and aim for the highest degree of photorealism. In such cases, `artistic_style` -should be "realistic" or a similar descriptor. - -The JSON object must contain the following keys precisely: - -{json_schema_full} - -Ensure the information within the JSON is detailed, creative, internally consistent, and avoids redundancy between -fields.""" - - system_prompts[ - "RefineA" - ] = f"""You are a Meticulous Visual Editor and Senior Art Director at a leading Generative AI company. - -Your expertise is in refining and modifying existing visual concepts based on precise feedback. - -Your primary task is to receive an existing JSON object that describes a visual scene, along with a textual instruction -for how to change it. You must then generate a new, updated JSON object that perfectly incorporates the requested -changes. - -Adhere strictly to the following structure and guidelines: - -1. **Input:** You will receive two pieces of information: an existing JSON object and a textual instruction. -2. **Output:** Your output MUST be ONLY a single, valid JSON object in the specified schema. Do not include any text - before or after the JSON object. -3. **Modification Logic:** - * Carefully parse the user's textual instruction to understand the desired changes. - * Modify ONLY the fields in the JSON that are directly or logically affected by the instruction. - * All other fields not relevant to the change must be copied exactly from the original JSON. Do not alter or omit - them. -4. **Holistic Consistency (IMPORTANT):** Changes in one field must be logically reflected in others. For example: - * If the instruction is to "change the background to a snowy forest," you must update the `background_setting` - field, and also update the `short_description` to mention the new setting. The `mood_atmosphere` might also need - to change to "serene" or "wintry." - * If the instruction is to "add the text 'WINTER SALE' at the top," you must add a new entry to the `text_render` - array. 
- * If the instruction is to "make the person smile," you must update the `expression` field for that object and - potentially update the overall `mood_atmosphere`. -5. **Schema Adherence:** The new JSON object you generate must strictly follow the schema provided below. - -The JSON object must contain the following keys precisely: - -{json_schema_full}""" - - system_prompts[ - "RefineB" - ] = f"""You are an advanced Multimodal Visual Specialist at a leading Generative AI company. - -Your unique expertise is in analyzing and editing visual concepts by processing an image, its corresponding JSON -metadata, and textual feedback simultaneously. - -Your primary task is to receive three inputs: an existing image, its descriptive JSON object, and a textual instruction -for a modification. You must use the image as the primary source of truth to understand the context of the requested -change and then generate a new, updated JSON object that accurately reflects that change. - -Adhere strictly to the following structure and guidelines: - -1. **Inputs:** You will receive an image, an existing JSON object, and a textual instruction. -2. **Visual Grounding (IMPORTANT):** The provided image is the ground truth. Use it to visually verify the contents of - the scene and to understand the context of the user's edit instruction. For example, if the instruction is "make - the car blue," visually locate the car in the image to inform your edits to the JSON. -3. **Output:** Your output MUST be ONLY a single, valid JSON object in the specified schema. Do not include any text - before or after the JSON object. -4. **Modification Logic:** - * Analyze the user's textual instruction in the context of what you see in the image. - * Modify ONLY the fields in the JSON that are directly or logically affected by the instruction. - * All other fields not relevant to the change must be copied exactly from the original JSON. -5. **Holistic Consistency:** Changes must be reflected logically across the JSON, consistent with a potential visual - change to the image. For instance, changing the lighting from 'daylight' to 'golden hour' should not only update - the `lighting` object but also the `mood_atmosphere`, `shadows`, and the `short_description`. -6. **Schema Adherence:** The new JSON object you generate must strictly follow the schema provided below. - -The JSON object must contain the following keys precisely: - -{json_schema_full}""" - - system_prompts[ - "InspireA" - ] = f"""You are a highly skilled Creative Director for Visual Adaptation at a leading Generative AI company. - -Your expertise lies in using an existing image as a visual reference to create entirely new scenes. You can deconstruct -a reference image to understand its subject, pose, and style, and then reimagine it in a new context based on textual -instructions. - -Your primary task is to receive a reference image and a textual instruction. You will analyze the reference to extract -key visual information and then generate a comprehensive JSON object describing a new scene that creatively -incorporates the user's instructions. - -Adhere strictly to the following structure and guidelines: - -1. **Inputs:** You will receive a reference image and a textual instruction. You will NOT receive a starting JSON. -2. **Core Logic (Analyze and Synthesize):** - * **Analyze:** First, deeply analyze the provided reference image. Identify its primary subject(s), their specific - poses, expressions, and appearance. 
Also note the overall composition, lighting style, and artistic medium. - * **Synthesize:** Next, interpret the textual instruction to understand what elements to keep from the reference - and what to change. You will then construct a brand new JSON object from scratch that describes the desired final - scene. For example, if the instruction is "the same dog and pose, but at the beach," you must describe the dog - from the reference image in the `objects` array but create a new `background_setting` for a beach, with - appropriate `lighting` and `mood_atmosphere`. -3. **Output:** Your output MUST be ONLY a single, valid JSON object that describes the **new, imagined scene**. Do not - describe the original reference image. -4. **Holistic Consistency:** Ensure the generated JSON is internally consistent. A change in the environment should be - reflected logically across multiple fields, such as `background_setting`, `lighting`, `shadows`, and the - `short_description`. -5. **Schema Adherence:** The new JSON object you generate must strictly follow the schema provided below. - -The JSON object must contain the following keys precisely: - -{json_schema_full}""" - - system_prompts["InspireB"] = system_prompts["Caption"] - - final_prompts = {} - - final_prompts["Generate"] = ( - "Generate a detailed JSON object, adhering to the expected schema, for an imagined scene based on the following request: {user_prompt}." - ) - - final_prompts["RefineA"] = """ - [EXISTING JSON]: {json_data} - - [EDIT INSTRUCTIONS]: {user_prompt} - - [TASK]: Generate the new, updated JSON object that incorporates the edit instructions. Follow all system rules - for modification, consistency, and formatting. - """ - - final_prompts["RefineB"] = """ - [EXISTING JSON]: {json_data} - - [EDIT INSTRUCTIONS]: {user_prompt} - - [TASK]: Analyze the provided image and its contextual JSON. Then, generate the new, updated JSON object that - incorporates the edit instructions. Follow all your system rules for visual analysis, modification, and - consistency. - """ - - final_prompts["InspireA"] = """ - [EDIT INSTRUCTIONS]: {user_prompt} - - [TASK]: Use the provided image as a visual reference only. Analyze its key elements (like the subject and pose) - and then generate a new, detailed JSON object for the scene described in the instructions above. Do not - describe the reference image itself; describe the new scene. Follow all of your system rules. - """ - - final_prompts["Caption"] = ( - "Analyze the provided image and generate the detailed JSON object as specified in your instructions." - ) - final_prompts["InspireB"] = final_prompts["Caption"] - - return system_prompts.get(mode, ""), final_prompts.get(mode, "") - - -def keep(p, k, v): - is_none = v is None - is_empty_string = isinstance(v, str) and v == "" - is_empty_dict = isinstance(v, dict) and not v - is_empty_list = isinstance(v, list) and not v - is_nan = isinstance(v, float) and math.isnan(v) - if is_none or is_empty_string or is_empty_list or is_empty_dict: - return False - if is_nan: - return False - return True - - -def validate_json(json_data: dict) -> dict: - ia = ImageAnalysis.model_validate_json(json_data, strict=True) - return ia.model_dump(exclude_none=True) - - -def validate_structured_prompt_str(structured_prompt_str: str) -> str: - ia = ImageAnalysis.model_validate_json(structured_prompt_str, strict=True) - c = ia.model_dump(exclude_none=True) - return json.dumps(c) - - -def prepare_clean_caption(json_dump: dict) -> str: - # filter empty values recursivly (i.e. 
None, "", {}, [], float("nan")) - clean_caption_dict = remap(json_dump, visit=keep) - - scores = {"preference_score": "very high", "aesthetic_score": "very high"} - # Set aesthetics scores - if "aesthetics" not in clean_caption_dict: - clean_caption_dict["aesthetics"] = scores - else: - clean_caption_dict["aesthetics"].update(scores) - - # Dumps clean structured caption as minimal json string (i.e. no newlines\whitespaces seps) - clean_caption_str = json.dumps(clean_caption_dict) - return clean_caption_str - - -# resize an input image to have a specific number of pixels (1,048,576 or 1024×1024) -# while maintaining a certain aspect ratio and granularity (output width and height must be multiples of this number). -def resize_image_by_num_pixels( - image: Image.Image, pixel_number: int = 1048576, granularity_val: int = 64, target_ratio: float = 0.0 -) -> Image.Image: - if target_ratio != 0.0: - ratio = target_ratio - else: - ratio = image.size[0] / image.size[1] - width = int((pixel_number * ratio) ** 0.5) - width = width - (width % granularity_val) - height = int(pixel_number / width) - height = height - (height % granularity_val) - return image.resize((width, height)) - - -def infer_with_gemini( - client: genai.Client, - final_prompt: str, - system_prompt: str, - top_p: float, - temperature: float, - max_tokens: int, - seed: int = 42, - image: Optional[Image.Image] = None, - model: str = "gemini-2.5-flash", -) -> str: - """ - Calls Gemini API with the given prompt and returns the raw JSON response. - - Args: - final_prompt: The text prompt to send to Gemini - system_prompt: The system instruction for Gemini - existing_image_path: Optional path to an image file to include - model: The Gemini model to use - - Returns: - Raw JSON response text from Gemini - """ - parts = [{"text": final_prompt}] - if image: - # Save image into bytes - image = image.convert("RGB") # the model can't produce rgba so sending them as input has no effect - less_then = 262144 - if image.size[0] * image.size[1] > less_then: - image = resize_image_by_num_pixels( - image, pixel_number=less_then, granularity_val=1, target_ratio=0.0 - ) # 512x512 - buffer = io.BytesIO() - image.save(buffer, format="JPEG") - image_bytes = buffer.getvalue() - - img_part = { - "inlineData": { - "data": image_bytes, - "mimeType": "image/jpeg", - } - } - parts.append(img_part) - - contents = [{"role": "user", "parts": parts}] - - generationConfig = { - "temperature": temperature, - "topP": top_p, - "maxOutputTokens": max_tokens, - "response_mime_type": "application/json", - "response_schema": get_gemini_output_schema(), - "system_instruction": system_prompt, # len 5900 - "thinkingConfig": {"thinkingBudget": 0}, - "seed": seed, - } - - response = client.models.generate_content( - model=model, - contents=contents, - config=generationConfig, - ) - - if response.candidates[0].finish_reason == "MAX_TOKENS": - raise Exception("Max tokens") - - return response.candidates[0].content.parts[0].text - - -def get_default_negative_prompt(existing_json: dict) -> str: - negative_prompt = "" - style_medium = existing_json.get("style_medium", "").lower() - if style_medium in ["photograph", "photography", "photo"]: - negative_prompt = """{'style_medium': 'digital illustration', 'artistic_style': 'non-realistic'}""" - return negative_prompt - - -def json_promptify( - client: genai.Client, - model_id: str, - top_p: float, - temperature: float, - max_tokens: int, - user_prompt: Optional[str] = None, - existing_json: Optional[str] = None, - image: 
Optional[Image.Image] = None, - seed: int = 42, -) -> str: - if existing_json: - # make sure aesthetic scores are not in the existing json (will be added later) - existing_json = json.loads(existing_json) - if "aesthetics" in existing_json: - existing_json["aesthetics"].pop("aesthetic_score", None) - existing_json["aesthetics"].pop("preference_score", None) - existing_json = json.dumps(existing_json) - - if not user_prompt: - raise ValueError("user_prompt is required if existing_json is provided") - - if image: - mode = "RefineB" - system_prompt, final_prompt = get_instructions(mode) - final_prompt = final_prompt.format(user_prompt=user_prompt, json_data=existing_json) - - else: - mode = "RefineA" - system_prompt, final_prompt = get_instructions(mode) - final_prompt = final_prompt.format(user_prompt=user_prompt, json_data=existing_json) - elif image and user_prompt: - mode = "InspireA" - system_prompt, final_prompt = get_instructions(mode) - final_prompt = final_prompt.format(user_prompt=user_prompt) - elif image and not user_prompt: - mode = "Caption" - system_prompt, final_prompt = get_instructions(mode) - else: - mode = "Generate" - system_prompt, final_prompt = get_instructions(mode) - final_prompt = final_prompt.format(user_prompt=user_prompt) - - json_data = infer_with_gemini( - client=client, - model=model_id, - final_prompt=final_prompt, - system_prompt=system_prompt, - seed=seed, - image=image, - top_p=top_p, - temperature=temperature, - max_tokens=max_tokens, - ) - json_data = validate_json(json_data) - clean_caption = prepare_clean_caption(json_data) - - return clean_caption - - -class BriaFiboGeminiPromptToJson(ModularPipelineBlocks): - model_name = "BriaFibo" - - def __init__(self, model_id="gemini-2.5-flash"): - super().__init__() - api_key = os.getenv("GOOGLE_API_KEY") - if api_key is None: - raise ValueError("Must provide an API key for Gemini through the `GOOGLE_API_KEY` env variable.") - self.model_id = model_id - - @property - def expected_components(self): - return [] - - @property - def inputs(self) -> List[InputParam]: - task_input = InputParam("task", type_hint=str, required=False, description="VLM Task to execute") - prompt_input = InputParam( - "prompt", - type_hint=str, - required=False, - description="Prompt to use", - ) - image_input = InputParam( - name="image", type_hint=Image.Image, required=False, description="image for inspiration mode" - ) - json_prompt_input = InputParam( - name="json_prompt", type_hint=str, required=False, description="JSON prompt to use" - ) - sampling_top_p_input = InputParam( - name="sampling_top_p", type_hint=float, required=False, description="Sampling top p", default=1.0 - ) - sampling_temperature_input = InputParam( - name="sampling_temperature", - type_hint=float, - required=False, - description="Sampling temperature", - default=0.2, - ) - sampling_max_tokens_input = InputParam( - name="sampling_max_tokens", type_hint=int, required=False, description="Sampling max tokens", default=3000 - ) - return [ - task_input, - prompt_input, - image_input, - json_prompt_input, - sampling_top_p_input, - sampling_temperature_input, - sampling_max_tokens_input, - ] - - @property - def intermediate_inputs(self) -> List[InputParam]: - return [] - - @property - def intermediate_outputs(self) -> List[OutputParam]: - return [ - OutputParam( - "json_prompt", - type_hint=str, - description="JSON prompt by the VLM", - ) - ] - - def __call__(self, components, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - prompt = 
block_state.prompt - image = block_state.image - json_prompt = block_state.json_prompt - client = genai.Client() - json_prompt = json_promptify( - client=client, - model_id=self.model_id, - top_p=block_state.sampling_top_p, - temperature=block_state.sampling_temperature, - max_tokens=block_state.sampling_max_tokens, - user_prompt=prompt, - existing_json=json_prompt, - image=image, - ) - block_state.json_prompt = json_prompt - self.set_block_state(state, block_state) - - return components, state diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py index 690b54607d..ef86997155 100644 --- a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py +++ b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py @@ -1,3 +1,13 @@ +# Copyright (c) Bria.ai. All rights reserved. +# +# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0). +# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/ +# +# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit, +# indicate if changes were made, and do not use the material for commercial purposes. +# +# See the license for further details. + from typing import Any, Callable, Dict, List, Optional, Union import numpy as np @@ -35,20 +45,47 @@ else: logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ - ... + Example: + ```python + import torch + from diffusers import BriaFiboPipeline + from diffusers.modular_pipelines import ModularPipeline + + torch.set_grad_enabled(False) + vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True) + + pipe = BriaFiboPipeline.from_pretrained( + "briaai/FIBO", + trust_remote_code=True, + torch_dtype=torch.bfloat16, + ) + pipe.enable_sequential_cpu_offload() + + with torch.inference_mode(): + # 1. Create a prompt to generate an initial image + output = vlm_pipe(prompt="a beautiful dog") + json_prompt_generate = output.values["json_prompt"] + + # Generate the image from the structured json prompt + results_generate = pipe(prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=5) + results_generate.images[0].save("image_generate.png") + ``` """ class BriaFiboPipeline(DiffusionPipeline): r""" Args: - transformer ([`GaiaTransformer2DModel`]): - scheduler ([`FlowMatchEulerDiscreteScheduler`]): - A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`SmolLM3ForCausalLM`]): + transformer (`BriaFiboTransformer2DModel`): + The transformer model for 2D diffusion modeling. + scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`): + Scheduler to be used with `transformer` to denoise the encoded latents. + vae (`AutoencoderKLWan`): + Variational Auto-Encoder for encoding and decoding images to and from latent representations. + text_encoder (`SmolLM3ForCausalLM`): + Text encoder for processing input prompts. tokenizer (`AutoTokenizer`): + Tokenizer used for processing the input text prompts for the text_encoder. 
""" model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" @@ -166,7 +203,7 @@ class BriaFiboPipeline(DiffusionPipeline): prompt: Union[str, List[str]], device: Optional[torch.device] = None, num_images_per_prompt: int = 1, - do_classifier_free_guidance: bool = True, + guidance_scale: float = 5, negative_prompt: Optional[Union[str, List[str]]] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, @@ -181,8 +218,8 @@ class BriaFiboPipeline(DiffusionPipeline): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not + guidance_scale (`float`): + Guidance scale for classifier free guidance. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is @@ -224,7 +261,7 @@ class BriaFiboPipeline(DiffusionPipeline): prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype) prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers] - if do_classifier_free_guidance: + if guidance_scale > 1: if isinstance(negative_prompt, list) and negative_prompt[0] is None: negative_prompt = "" negative_prompt = negative_prompt or "" @@ -302,9 +339,6 @@ class BriaFiboPipeline(DiffusionPipeline): # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1 @property def joint_attention_kwargs(self): @@ -320,6 +354,7 @@ class BriaFiboPipeline(DiffusionPipeline): @staticmethod def _unpack_latents(latents, height, width, vae_scale_factor): + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline batch_size, num_patches, channels = latents.shape height = height // vae_scale_factor @@ -366,6 +401,7 @@ class BriaFiboPipeline(DiffusionPipeline): @staticmethod def _pack_latents(latents, batch_size, num_channels_latents, height, width): + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) latents = latents.permute(0, 2, 4, 1, 3, 5) latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) @@ -410,34 +446,7 @@ class BriaFiboPipeline(DiffusionPipeline): return latents, latent_image_ids @staticmethod - def init_inference_scheduler(height, width, device, image_seq_len, num_inference_steps=1000, noise_scheduler=None): - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - - assert height % 16 == 0 and width % 16 == 0 - - mu = calculate_shift( - image_seq_len, - noise_scheduler.config.base_image_seq_len, - noise_scheduler.config.max_image_seq_len, - noise_scheduler.config.base_shift, - noise_scheduler.config.max_shift, - ) - - # Init sigmas and timesteps according to shift size - # This changes the scheduler in-place according to the dynamic scheduling - timesteps, num_inference_steps = retrieve_timesteps( - noise_scheduler, - num_inference_steps=num_inference_steps, - device=device, - timesteps=None, - sigmas=sigmas, - mu=mu, - ) - - return 
noise_scheduler, timesteps, num_inference_steps, mu - - @staticmethod - def create_attention_matrix(attention_mask): + def _prepare_attention_mask(attention_mask): attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask) # convert to 0 - keep, -inf ignore @@ -583,7 +592,7 @@ class BriaFiboPipeline(DiffusionPipeline): ) = self.encode_prompt( prompt=prompt, negative_prompt=negative_prompt, - do_classifier_free_guidance=self.do_classifier_free_guidance, + guidance_scale=guidance_scale, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, device=device, @@ -593,7 +602,7 @@ class BriaFiboPipeline(DiffusionPipeline): ) prompt_batch_size = prompt_embeds.shape[0] - if self.do_classifier_free_guidance: + if guidance_scale > 1: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_layers = [ torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers)) @@ -611,6 +620,7 @@ class BriaFiboPipeline(DiffusionPipeline): prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers)) # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels if do_patching: num_channels_latents = int(num_channels_latents / 4) @@ -630,11 +640,11 @@ class BriaFiboPipeline(DiffusionPipeline): latent_attention_mask = torch.ones( [latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device ) - if self.do_classifier_free_guidance: + if guidance_scale > 1: latent_attention_mask = latent_attention_mask.repeat(2, 1) attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1) - attention_mask = self.create_attention_matrix(attention_mask) # batch, seq => batch, seq, seq + attention_mask = self._prepare_attention_mask(attention_mask) # batch, seq => batch, seq, seq attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting if self._joint_attention_kwargs is None: @@ -648,13 +658,25 @@ class BriaFiboPipeline(DiffusionPipeline): else: seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor) - self.noise_scheduler, timesteps, num_inference_steps, mu = self.init_inference_scheduler( - height=height, - width=width, - device=device, + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + + mu = calculate_shift( + seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + + # Init sigmas and timesteps according to shift size + # This changes the scheduler in-place according to the dynamic scheduling + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps=num_inference_steps, - noise_scheduler=self.scheduler, - image_seq_len=seq_len, + device=device, + timesteps=None, + sigmas=sigmas, + mu=mu, ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) @@ -674,10 +696,7 @@ class BriaFiboPipeline(DiffusionPipeline): continue # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - - if type(self.scheduler) != FlowMatchEulerDiscreteScheduler: - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents # broadcast to batch dimension in a way that's compatible 
with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]).to( @@ -697,7 +716,7 @@ class BriaFiboPipeline(DiffusionPipeline): )[0] # perform guidance - if self.do_classifier_free_guidance: + if guidance_scale > 1: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) @@ -736,7 +755,6 @@ class BriaFiboPipeline(DiffusionPipeline): else: latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor) - latents = latents.to(dtype=self.vae.dtype) latents = latents.unsqueeze(dim=2) latents = list(torch.unbind(latents, dim=0)) latents_device = latents[0].device @@ -780,8 +798,8 @@ class BriaFiboPipeline(DiffusionPipeline): callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs @@ -818,9 +836,3 @@ class BriaFiboPipeline(DiffusionPipeline): if max_sequence_length is not None and max_sequence_length > 3000: raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}") - - def to(self, *args, **kwargs): - DiffusionPipeline.to(self, *args, **kwargs) - # We use as float32 since wan22 in their repo use it like this - self.vae.to(dtype=torch.float32) - return self diff --git a/tests/models/transformers/test_models_transformer_bria_fibo.py b/tests/models/transformers/test_models_transformer_bria_fibo.py index ad87b5710a..f859f4608b 100644 --- a/tests/models/transformers/test_models_transformer_bria_fibo.py +++ b/tests/models/transformers/test_models_transformer_bria_fibo.py @@ -20,7 +20,7 @@ import torch from diffusers import BriaFiboTransformer2DModel from ...testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin +from ..test_modeling_common import ModelTesterMixin enable_full_determinism() @@ -84,48 +84,6 @@ class BriaFiboTransformerTests(ModelTesterMixin, unittest.TestCase): inputs_dict = self.dummy_input return init_dict, inputs_dict - def test_deprecated_inputs_img_txt_ids_3d(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - output_1 = model(**inputs_dict)[0] - - # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated) - text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0) - image_ids_3d = inputs_dict["img_ids"].unsqueeze(0) - - assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor" - assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor" - - inputs_dict["txt_ids"] = text_ids_3d - inputs_dict["img_ids"] = image_ids_3d - - with torch.no_grad(): - output_2 = model(**inputs_dict)[0] - - self.assertEqual(output_1.shape, output_2.shape) - self.assertTrue( - torch.allclose(output_1, output_2, atol=1e-5), - msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs", - ) - def test_gradient_checkpointing_is_applied(self): expected_set = {"BriaFiboTransformer2DModel"} 
super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - - -class BriaFiboTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): - model_class = BriaFiboTransformer2DModel - - def prepare_init_args_and_inputs_for_common(self): - return BriaFiboTransformerTests().prepare_init_args_and_inputs_for_common() - - -class BriaFiboTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase): - model_class = BriaFiboTransformer2DModel - - def prepare_init_args_and_inputs_for_common(self): - return BriaFiboTransformerTests().prepare_init_args_and_inputs_for_common() diff --git a/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py b/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py index 15cdb82fe1..969634f597 100644 --- a/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py +++ b/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py @@ -128,7 +128,7 @@ class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipe = pipe.to(torch_device) inputs = self.get_dummy_inputs(torch_device) - height_width_pairs = [(32, 32)] + height_width_pairs = [(32, 32), (64, 64), (32, 64)] for height, width in height_width_pairs: expected_height = height expected_width = width @@ -181,13 +181,18 @@ class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase): max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading." ) - def test_to_dtype(self): - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.set_progress_bar_config(disable=None) + # def test_to_dtype(self): + # components = self.get_dummy_components() + # pipe = self.pipeline_class(**components) + # pipe.set_progress_bar_config(disable=None) - model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] - self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) + # model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] + # self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) + # pipe.to(dtype=torch.float16) + # model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] + # self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) + + @unittest.skip("") def test_save_load_dduf(self): pass
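
Note on the latent packing used in the pipeline hunks above: `BriaFiboPipeline._pack_latents` folds each 2x2 patch of the latent grid into a single sequence token before the transformer is called, and `_unpack_latents` reverses the operation before VAE decoding (both carry the new "Copied from ... QwenImagePipeline" annotations in this patch). The sketch below replays that round trip on dummy tensors. It is illustrative only: the tensor sizes and `vae_scale_factor` are assumed values, not ones taken from the FIBO checkpoint, and the body of `unpack_latents` is reconstructed from the context lines visible in the hunk rather than copied from the full file.

```python
# Minimal sketch of the 2x2 latent packing/unpacking used by BriaFiboPipeline.
# Shapes are dummy values; the pixel height/width respect the divisible-by-16
# rule that check_inputs now enforces in this patch.
import torch


def pack_latents(latents, batch_size, num_channels_latents, height, width):
    # (B, C, H, W) -> (B, (H/2) * (W/2), C * 4): every 2x2 spatial patch becomes one token.
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)


def unpack_latents(latents, height, width, vae_scale_factor):
    # Inverse of pack_latents; height/width are given in pixel space.
    batch_size, num_patches, channels = latents.shape
    height = height // vae_scale_factor
    width = width // vae_scale_factor
    latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    return latents.reshape(batch_size, channels // 4, height, width)


batch, channels, latent_h, latent_w, vae_scale = 1, 16, 64, 64, 8  # assumed values
x = torch.randn(batch, channels, latent_h, latent_w)
packed = pack_latents(x, batch, channels, latent_h, latent_w)  # -> (1, 1024, 64)
restored = unpack_latents(packed, latent_h * vae_scale, latent_w * vae_scale, vae_scale)
assert torch.equal(x, restored)
```

The resulting sequence length, (H/2) * (W/2) tokens when patching is enabled, is the same `seq_len` the pipeline now passes to `calculate_shift` in the inlined scheduler setup, which replaces the removed `init_inference_scheduler` helper.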