fix CR
@@ -323,6 +323,8 @@
  title: AllegroTransformer3DModel
- local: api/models/aura_flow_transformer2d
  title: AuraFlowTransformer2DModel
- local: api/models/transformer_bria_fibo
  title: BriaFiboTransformer2DModel
- local: api/models/bria_transformer
  title: BriaTransformer2DModel
- local: api/models/chroma_transformer
@@ -469,6 +471,8 @@
  title: BLIP-Diffusion
- local: api/pipelines/bria_3_2
  title: Bria 3.2
- local: api/pipelines/bria_fibo
  title: Bria Fibo
- local: api/pipelines/chroma
  title: Chroma
- local: api/pipelines/cogview3

@@ -1,18 +1,26 @@
# Copyright (c) Bria.ai. All rights reserved.
#
# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
#
# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
# indicate if changes were made, and do not use the material for commercial purposes.
#
# See the license for further details.
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention_processor import Attention
from ...models.embeddings import TimestepEmbedding, get_timestep_embedding
from ...models.embeddings import TimestepEmbedding, apply_rotary_emb, get_1d_rotary_pos_embed, get_timestep_embedding
from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZeroSingle
from ...models.transformers.transformer_bria import BriaAttnProcessor
from ...models.transformers.transformer_flux import FluxTransformerBlock
from ...utils import (
    USE_PEFT_BACKEND,
    logging,
@@ -20,76 +28,193 @@ from ...utils import (
    unscale_lora_layers,
)
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def get_1d_rotary_pos_embed(
|
||||
dim: int,
|
||||
pos: Union[np.ndarray, int],
|
||||
theta: float = 10000.0,
|
||||
use_real=False,
|
||||
linear_factor=1.0,
|
||||
ntk_factor=1.0,
|
||||
repeat_interleave_real=True,
|
||||
freqs_dtype=torch.float32, # torch.float32, torch.float64
|
||||
):
|
||||
"""
|
||||
Precompute the frequency tensor for complex exponentials (cis) with given dimensions. This function calculates a
|
||||
frequency tensor with complex exponentials using the given dimension 'dim' and the end index 'end'. The 'theta'
|
||||
parameter scales the frequencies. The returned tensor contains complex values in complex64 data type.
|
||||
def _get_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
Args:
|
||||
dim (`int`): Dimension of the frequency tensor.
|
||||
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
|
||||
theta (`float`, *optional*, defaults to 10000.0):
|
||||
Scaling factor for frequency computation. Defaults to 10000.0.
|
||||
use_real (`bool`, *optional*):
|
||||
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
linear_factor (`float`, *optional*, defaults to 1.0):
|
||||
Scaling factor for the context extrapolation. Defaults to 1.0.
|
||||
ntk_factor (`float`, *optional*, defaults to 1.0):
|
||||
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
|
||||
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
|
||||
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
|
||||
Otherwise, they are concatenated with themselves.
|
||||
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
|
||||
the dtype of the frequency tensor.
|
||||
Returns:
|
||||
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
|
||||
"""
|
||||
assert dim % 2 == 0
|
||||
encoder_query = encoder_key = encoder_value = None
|
||||
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
|
||||
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
if isinstance(pos, int):
|
||||
pos = torch.arange(pos)
|
||||
if isinstance(pos, np.ndarray):
|
||||
pos = torch.from_numpy(pos) # type: ignore # [S]
|
||||
|
||||
theta = theta * ntk_factor
|
||||
freqs = (
|
||||
1.0
|
||||
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
|
||||
/ linear_factor
|
||||
) # [D/2]
|
||||
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
|
||||
if use_real and repeat_interleave_real:
|
||||
# flux, hunyuan-dit, cogvideox
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
elif use_real:
|
||||
# stable audio, allegro
|
||||
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
|
||||
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
else:
|
||||
# lumina
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
|
||||
return freqs_cis
|
||||
return query, key, value, encoder_query, encoder_key, encoder_value
|
||||
|
||||
|
||||
class EmbedND(torch.nn.Module):
|
||||
def _get_fused_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
|
||||
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
|
||||
|
||||
encoder_query = encoder_key = encoder_value = (None,)
|
||||
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
|
||||
encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
|
||||
|
||||
return query, key, value, encoder_query, encoder_key, encoder_value
|
||||
|
||||
|
||||
def _get_qkv_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
|
||||
if attn.fused_projections:
|
||||
return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
|
||||
return _get_projections(attn, hidden_states, encoder_hidden_states)
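# Illustrative note (not part of the diff): the helpers above implement the fused/unfused projection
# split used elsewhere in diffusers. When the q/k/v linears have been merged into a single `to_qkv`
# layer (the path toggled by `attn.fused_projections`), chunking its output reproduces the three
# separate projections. A minimal, self-contained sketch of that equivalence in plain PyTorch, with
# hypothetical shapes:

import torch
import torch.nn as nn

dim, seq = 64, 4
x = torch.randn(1, seq, dim)

# Three separate projections, as in _get_projections.
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=False) for _ in range(3))

# One fused projection whose weight stacks the three, as in _get_fused_projections.
to_qkv = nn.Linear(dim, 3 * dim, bias=False)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))

q, k, v = to_qkv(x).chunk(3, dim=-1)
assert torch.allclose(q, to_q(x), atol=1e-5)
assert torch.allclose(k, to_k(x), atol=1e-5)
assert torch.allclose(v, to_v(x), atol=1e-5)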
|
||||
|
||||
|
||||
class BriaFiboAttnProcessor:
|
||||
# Copied from diffusers.models.transformers.transformer_flux.FluxAttnProcessor
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "BriaFiboAttention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
|
||||
attn, hidden_states, encoder_hidden_states
|
||||
)
|
||||
|
||||
query = query.unflatten(-1, (attn.heads, -1))
|
||||
key = key.unflatten(-1, (attn.heads, -1))
|
||||
value = value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if attn.added_kv_proj_dim is not None:
|
||||
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
|
||||
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
|
||||
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
encoder_query = attn.norm_added_q(encoder_query)
|
||||
encoder_key = attn.norm_added_k(encoder_key)
|
||||
|
||||
query = torch.cat([encoder_query, query], dim=1)
|
||||
key = torch.cat([encoder_key, key], dim=1)
|
||||
value = torch.cat([encoder_value, value], dim=1)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
|
||||
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
|
||||
)
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
return hidden_states
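# Illustrative note (not part of the diff): the processor prepends the projected text tokens to the
# image tokens, runs a single attention call over the joint sequence, and then splits the result back
# into the two streams. A rough, self-contained sketch of that pattern with plain
# scaled_dot_product_attention (layout simplified to (batch, heads, seq, head_dim); the real code goes
# through dispatch_attention_fn):

import torch
import torch.nn.functional as F

batch, heads, txt_len, img_len, head_dim = 1, 2, 3, 5, 8
txt = torch.randn(batch, heads, txt_len, head_dim)
img = torch.randn(batch, heads, img_len, head_dim)

# Text tokens come first, matching the torch.cat([encoder_*, *], dim=1) calls above.
q = k = v = torch.cat([txt, img], dim=2)
out = F.scaled_dot_product_attention(q, k, v)

# The joint output is split back into the two streams, mirroring split_with_sizes in __call__.
txt_out, img_out = out.split_with_sizes([txt_len, img_len], dim=2)
print(txt_out.shape, img_out.shape)  # torch.Size([1, 2, 3, 8]) torch.Size([1, 2, 5, 8])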
|
||||
|
||||
|
||||
class BriaFiboAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
# Copied from diffusers.models.transformers.transformer_flux.FluxAttention
|
||||
_default_processor_cls = BriaFiboAttnProcessor
|
||||
_available_processors = [
|
||||
BriaFiboAttnProcessor,
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias: bool = False,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
added_proj_bias: Optional[bool] = True,
|
||||
out_bias: bool = True,
|
||||
eps: float = 1e-5,
|
||||
out_dim: int = None,
|
||||
context_pre_only: Optional[bool] = None,
|
||||
pre_only: bool = False,
|
||||
elementwise_affine: bool = True,
|
||||
processor=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.head_dim = dim_head
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.query_dim = query_dim
|
||||
self.use_bias = bias
|
||||
self.dropout = dropout
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
self.context_pre_only = context_pre_only
|
||||
self.pre_only = pre_only
|
||||
self.heads = out_dim // dim_head if out_dim is not None else heads
|
||||
self.added_kv_proj_dim = added_kv_proj_dim
|
||||
self.added_proj_bias = added_proj_bias
|
||||
|
||||
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
|
||||
if not self.pre_only:
|
||||
self.to_out = torch.nn.ModuleList([])
|
||||
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
||||
self.to_out.append(torch.nn.Dropout(dropout))
|
||||
|
||||
if added_kv_proj_dim is not None:
|
||||
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
|
||||
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
|
||||
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
|
||||
|
||||
if processor is None:
|
||||
processor = self._default_processor_cls()
|
||||
self.set_processor(processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
||||
quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
|
||||
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
|
||||
if len(unused_kwargs) > 0:
|
||||
logger.warning(
|
||||
f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
|
||||
)
|
||||
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
|
||||
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
|
||||
|
||||
|
||||
class FIBOEmbedND(torch.nn.Module):
|
||||
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
|
||||
def __init__(self, theta: int, axes_dim: List[int]):
|
||||
super().__init__()
|
||||
@@ -182,7 +307,93 @@ class TextProjection(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
@maybe_allow_in_graph
|
||||
class BriaFiboTransformerBlock(nn.Module):
|
||||
# Copied from diffusers.models.transformers.transformer_flux.FluxTransformerBlock
|
||||
def __init__(
|
||||
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = AdaLayerNormZero(dim)
|
||||
self.norm1_context = AdaLayerNormZero(dim)
|
||||
|
||||
self.attn = BriaFiboAttention(
|
||||
query_dim=dim,
|
||||
added_kv_proj_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
context_pre_only=False,
|
||||
bias=True,
|
||||
processor=BriaFiboAttnProcessor(),
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
|
||||
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
|
||||
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, emb=temb
|
||||
)
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
|
||||
# Attention.
|
||||
attention_outputs = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
if len(attention_outputs) == 2:
|
||||
attn_output, context_attn_output = attention_outputs
|
||||
elif len(attention_outputs) == 3:
|
||||
attn_output, context_attn_output, ip_attn_output = attention_outputs
|
||||
|
||||
# Process attention outputs for the `hidden_states`.
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
hidden_states = hidden_states + ff_output
|
||||
if len(attention_outputs) == 3:
|
||||
hidden_states = hidden_states + ip_attn_output
|
||||
|
||||
# Process attention outputs for the `encoder_hidden_states`.
|
||||
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
||||
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
||||
if encoder_hidden_states.dtype == torch.float16:
|
||||
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
||||
|
||||
return encoder_hidden_states, hidden_states
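# Illustrative note (not part of the diff): the block follows the AdaLayerNormZero recipe — the timestep
# embedding is projected into shift/scale/gate vectors that modulate the normalized activations and gate
# each residual branch. A compact sketch of that modulation for one branch, assuming `temb` has already
# been projected to 6 * dim features as AdaLayerNormZero does internally:

import torch
import torch.nn as nn

dim, seq = 16, 4
hidden_states = torch.randn(1, seq, dim)
temb_proj = torch.randn(1, 6 * dim)  # stand-in for linear(silu(temb))

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb_proj.chunk(6, dim=1)

norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
normed = norm(hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None]

attn_output = normed  # stand-in for the attention branch
hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_output  # gated residual, as in forward() above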
|
||||
|
||||
|
||||
class FIBOTimesteps(nn.Module):
|
||||
def __init__(
|
||||
self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
|
||||
):
|
||||
@@ -209,7 +420,7 @@ class TimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, time_theta):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(
|
||||
self.time_proj = FIBOTimesteps(
|
||||
num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
|
||||
)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
@@ -258,7 +469,7 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
||||
|
||||
self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
|
||||
self.pos_embed = FIBOEmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
|
||||
|
||||
self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
|
||||
|
||||
@@ -270,7 +481,7 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
FluxTransformerBlock(
|
||||
BriaFiboTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=self.config.num_attention_heads,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["fibo_vlm_prompt_to_json"] = ["BriaFiboVLMPromptToJson"]
|
||||
_import_structure["gemini_prompt_to_json"] = ["BriaFiboGeminiPromptToJson"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .fibo_vlm_prompt_to_json import BriaFiboVLMPromptToJson
|
||||
from .gemini_prompt_to_json import BriaFiboGeminiPromptToJson
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -1,373 +0,0 @@
|
||||
import json
|
||||
import math
|
||||
import textwrap
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
import torch
|
||||
from boltons.iterutils import remap
|
||||
from PIL import Image
|
||||
from transformers import AutoModelForCausalLM, AutoProcessor, Qwen3VLForConditionalGeneration
|
||||
|
||||
from .. import ComponentSpec, InputParam, ModularPipelineBlocks, OutputParam, PipelineState
|
||||
|
||||
|
||||
def parse_aesthetic_score(record: dict) -> str:
|
||||
ae = record["aesthetic_score"]
|
||||
if ae < 5.5:
|
||||
return "very low"
|
||||
elif ae < 6:
|
||||
return "low"
|
||||
elif ae < 7:
|
||||
return "medium"
|
||||
elif ae < 7.6:
|
||||
return "high"
|
||||
else:
|
||||
return "very high"
|
||||
|
||||
|
||||
def parse_pickascore(record: dict) -> str:
|
||||
ps = record["pickascore"]
|
||||
if ps < 0.78:
|
||||
return "very low"
|
||||
elif ps < 0.82:
|
||||
return "low"
|
||||
elif ps < 0.87:
|
||||
return "medium"
|
||||
elif ps < 0.91:
|
||||
return "high"
|
||||
else:
|
||||
return "very high"
|
||||
|
||||
|
||||
def prepare_clean_caption(record: dict) -> str:
|
||||
def keep(p, k, v):
|
||||
is_none = v is None
|
||||
is_empty_string = isinstance(v, str) and v == ""
|
||||
is_empty_dict = isinstance(v, dict) and not v
|
||||
is_empty_list = isinstance(v, list) and not v
|
||||
is_nan = isinstance(v, float) and math.isnan(v)
|
||||
if is_none or is_empty_string or is_empty_list or is_empty_dict or is_nan:
|
||||
return False
|
||||
return True
|
||||
|
||||
try:
|
||||
scores = {}
|
||||
if "pickascore" in record:
|
||||
scores["preference_score"] = parse_pickascore(record)
|
||||
if "aesthetic_score" in record:
|
||||
scores["aesthetic_score"] = parse_aesthetic_score(record)
|
||||
|
||||
clean_caption_dict = remap(record, visit=keep)
|
||||
|
||||
# Set aesthetics scores
|
||||
if "aesthetics" not in clean_caption_dict:
|
||||
if len(scores) > 0:
|
||||
clean_caption_dict["aesthetics"] = scores
|
||||
else:
|
||||
clean_caption_dict["aesthetics"].update(scores)
|
||||
|
||||
# Dump the clean structured caption as a minimal JSON string (i.e. no newline/whitespace separators)
|
||||
clean_caption_str = json.dumps(clean_caption_dict)
|
||||
return clean_caption_str
|
||||
except Exception as ex:
|
||||
print("Error: ", ex)
|
||||
raise ex
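# Illustrative note (not part of the diff): prepare_clean_caption drops empty/None/NaN fields via
# boltons.iterutils.remap, folds the bucketed scores into an "aesthetics" entry, and serializes the
# result compactly. Roughly, with the helpers above in scope:

record = {"short_description": "a cat", "objects": [], "aesthetic_score": 7.2}
print(prepare_clean_caption(record))
# -> {"short_description": "a cat", "aesthetic_score": 7.2, "aesthetics": {"aesthetic_score": "high"}}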
|
||||
|
||||
|
||||
def _collect_images(messages: Iterable[Dict[str, Any]]) -> List[Image.Image]:
|
||||
images: List[Image.Image] = []
|
||||
for message in messages:
|
||||
content = message.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if item.get("type") != "image":
|
||||
continue
|
||||
image_value = item.get("image")
|
||||
if isinstance(image_value, Image.Image):
|
||||
images.append(image_value)
|
||||
else:
|
||||
raise ValueError("Expected PIL.Image for image content in messages.")
|
||||
return images
|
||||
|
||||
|
||||
def _strip_stop_sequences(text: str, stop_sequences: Optional[List[str]]) -> str:
|
||||
if not stop_sequences:
|
||||
return text.strip()
|
||||
cleaned = text
|
||||
for stop in stop_sequences:
|
||||
if not stop:
|
||||
continue
|
||||
index = cleaned.find(stop)
|
||||
if index >= 0:
|
||||
cleaned = cleaned[:index]
|
||||
return cleaned.strip()
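# Illustrative note (not part of the diff): any stop marker and everything after it is trimmed before the
# model output is parsed as JSON, e.g.:

print(_strip_stop_sequences('{"short_description": "a cat"}<|im_end|>trailing text', ["<|im_end|>"]))
# -> {"short_description": "a cat"}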
|
||||
|
||||
|
||||
class TransformersEngine(torch.nn.Module):
|
||||
"""Inference wrapper using Hugging Face transformers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
*,
|
||||
processor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super(TransformersEngine, self).__init__()
|
||||
default_processor_kwargs: Dict[str, Any] = {
|
||||
"min_pixels": 256 * 28 * 28,
|
||||
"max_pixels": 1024 * 28 * 28,
|
||||
}
|
||||
processor_kwargs = {**default_processor_kwargs, **(processor_kwargs or {})}
|
||||
model_kwargs = model_kwargs or {}
|
||||
|
||||
self.processor = AutoProcessor.from_pretrained(model, **processor_kwargs)
|
||||
|
||||
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
|
||||
model,
|
||||
dtype=torch.bfloat16,
|
||||
**model_kwargs,
|
||||
)
|
||||
self.model.eval()
|
||||
|
||||
tokenizer_obj = self.processor.tokenizer
|
||||
if tokenizer_obj.pad_token_id is None:
|
||||
tokenizer_obj.pad_token = tokenizer_obj.eos_token
|
||||
self._pad_token_id = tokenizer_obj.pad_token_id
|
||||
eos_token_id = tokenizer_obj.eos_token_id
|
||||
if isinstance(eos_token_id, list) and eos_token_id:
|
||||
self._eos_token_id = eos_token_id
|
||||
elif eos_token_id is not None:
|
||||
self._eos_token_id = [eos_token_id]
|
||||
else:
|
||||
raise ValueError("Tokenizer must define an EOS token for generation.")
|
||||
|
||||
def dtype(self) -> torch.dtype:
|
||||
return self.model.dtype
|
||||
|
||||
def device(self) -> torch.device:
|
||||
return self.model.device
|
||||
|
||||
def _to_model_device(self, value: Any) -> Any:
|
||||
if not isinstance(value, torch.Tensor):
|
||||
return value
|
||||
target_device = getattr(self.model, "device", None)
|
||||
if target_device is None or target_device.type == "meta":
|
||||
return value
|
||||
if value.device == target_device:
|
||||
return value
|
||||
return value.to(target_device)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
top_p: float,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
stop: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
tokenizer = self.processor.tokenizer
|
||||
prompt_text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
processor_inputs: Dict[str, Any] = {
|
||||
"text": [prompt_text],
|
||||
"padding": True,
|
||||
"return_tensors": "pt",
|
||||
}
|
||||
images = _collect_images(messages)
|
||||
if images:
|
||||
processor_inputs["images"] = images
|
||||
inputs = self.processor(**processor_inputs)
|
||||
inputs = {key: self._to_model_device(value) for key, value in inputs.items()}
|
||||
|
||||
generation_kwargs: Dict[str, Any] = {
|
||||
"max_new_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"do_sample": temperature > 0,
|
||||
"eos_token_id": self._eos_token_id,
|
||||
"pad_token_id": self._pad_token_id,
|
||||
}
|
||||
|
||||
with torch.inference_mode():
|
||||
generated_ids = self.model.generate(**inputs, **generation_kwargs)
|
||||
|
||||
input_ids = inputs.get("input_ids")
|
||||
if input_ids is None:
|
||||
raise RuntimeError("Processor did not return input_ids; cannot compute new tokens.")
|
||||
new_token_ids = generated_ids[:, input_ids.shape[-1] :]
|
||||
decoded = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)
|
||||
if not decoded:
|
||||
return ""
|
||||
text = decoded[0]
|
||||
stripped_text = _strip_stop_sequences(text, stop)
|
||||
json_prompt = json.loads(stripped_text)
|
||||
return json_prompt
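# Illustrative note (not part of the diff): a typical way to drive the wrapper end to end. The checkpoint
# name is a placeholder and the call downloads weights, so this is a sketch rather than a drop-in test.

engine = TransformersEngine("<qwen3-vl-checkpoint>")  # placeholder model id
messages = [
    {"role": "user", "content": [{"type": "text", "text": "<generate>\nA red bicycle against a brick wall"}]}
]
structured = engine.generate(messages, top_p=0.9, temperature=0.2, max_tokens=4096, stop=["<|im_end|>"])
print(structured["short_description"])  # generate() json-decodes the model output before returning it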
|
||||
|
||||
|
||||
def generate_json_prompt(
|
||||
vlm_processor: AutoModelForCausalLM,
|
||||
top_p: float,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
stop: List[str],
|
||||
image: Optional[Image.Image] = None,
|
||||
prompt: Optional[str] = None,
|
||||
structured_prompt: Optional[str] = None,
|
||||
):
|
||||
if image is None and structured_prompt is None:
|
||||
# only got prompt
|
||||
task = "generate"
|
||||
editing_instructions = None
|
||||
elif image is None and structured_prompt is not None and prompt is not None:
|
||||
# got structured prompt and prompt
|
||||
task = "refine"
|
||||
editing_instructions = prompt
|
||||
elif image is not None and structured_prompt is None and prompt is not None:
|
||||
# got image and prompt
|
||||
task = "refine"
|
||||
editing_instructions = prompt
|
||||
elif image is not None and structured_prompt is None and prompt is None:
|
||||
# only got image
|
||||
task = "inspire"
|
||||
editing_instructions = None
|
||||
else:
|
||||
raise ValueError("Invalid input")
|
||||
|
||||
messages = build_messages(
|
||||
task,
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
structured_prompt=structured_prompt,
|
||||
editing_instructions=editing_instructions,
|
||||
)
|
||||
|
||||
generated_prompt = vlm_processor.generate(
|
||||
messages=messages, top_p=top_p, temperature=temperature, max_tokens=max_tokens, stop=stop
|
||||
)
|
||||
cleaned_json_data = prepare_clean_caption(generated_prompt)
|
||||
return cleaned_json_data
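# Illustrative note (not part of the diff): the branching above maps input combinations to VLM tasks —
# prompt only -> "generate"; prompt + structured_prompt, or image + prompt -> "refine" (the prompt becomes
# the editing instructions); image only -> "inspire"; anything else raises. A hedged sketch, reusing the
# engine wrapper defined earlier in this file:

json_str = generate_json_prompt(
    vlm_processor=engine,  # a TransformersEngine instance
    top_p=0.9,
    temperature=0.2,
    max_tokens=4096,
    stop=["<|im_end|>"],
    prompt="A red bicycle against a brick wall",
)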
|
||||
|
||||
|
||||
def build_messages(
|
||||
task: str,
|
||||
*,
|
||||
image: Optional[Image.Image] = None,
|
||||
refine_image: Optional[Image.Image] = None,
|
||||
prompt: Optional[str] = None,
|
||||
structured_prompt: Optional[str] = None,
|
||||
editing_instructions: Optional[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
user_content: List[Dict[str, Any]] = []
|
||||
|
||||
if task == "inspire":
|
||||
user_content.append({"type": "image", "image": image})
|
||||
user_content.append({"type": "text", "text": "<inspire>"})
|
||||
elif task == "generate":
|
||||
text_value = (prompt or "").strip()
|
||||
formatted = f"<generate>\n{text_value}"
|
||||
user_content.append({"type": "text", "text": formatted})
|
||||
else: # refine
|
||||
if refine_image is None:
|
||||
base_prompt = (structured_prompt or "").strip()
|
||||
edits = (editing_instructions or "").strip()
|
||||
formatted = textwrap.dedent(f"""<refine> Input: {base_prompt} Editing instructions: {edits}""").strip()
|
||||
user_content.append({"type": "text", "text": formatted})
|
||||
else:
|
||||
user_content.append({"type": "image", "image": refine_image})
|
||||
edits = (editing_instructions or "").strip()
|
||||
formatted = textwrap.dedent(f"""<refine> Editing instructions: {edits}""").strip()
|
||||
user_content.append({"type": "text", "text": formatted})
|
||||
|
||||
messages: List[Dict[str, Any]] = []
|
||||
messages.append({"role": "user", "content": user_content})
|
||||
return messages
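# Illustrative note (not part of the diff): for the plain text-to-JSON path the payload is a single user
# turn wrapping the tagged prompt:

print(build_messages("generate", prompt="A red bicycle against a brick wall"))
# -> [{'role': 'user', 'content': [{'type': 'text', 'text': '<generate>\nA red bicycle against a brick wall'}]}]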
|
||||
|
||||
|
||||
class BriaFiboVLMPromptToJson(ModularPipelineBlocks):
|
||||
model_name = "BriaFibo"
|
||||
|
||||
def __init__(self, model_id):
|
||||
super().__init__()
|
||||
self.engine = TransformersEngine(model_id)
|
||||
self.engine.model.to("cuda")
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
prompt_input = InputParam(
|
||||
"prompt",
|
||||
type_hint=str,
|
||||
required=False,
|
||||
description="Prompt to use",
|
||||
)
|
||||
image_input = InputParam(
|
||||
name="image", type_hint=Image.Image, required=False, description="image for inspiration mode"
|
||||
)
|
||||
json_prompt_input = InputParam(
|
||||
name="json_prompt", type_hint=str, required=False, description="JSON prompt to use"
|
||||
)
|
||||
sampling_top_p_input = InputParam(
|
||||
name="sampling_top_p", type_hint=float, required=False, description="Sampling top p", default=0.9
|
||||
)
|
||||
sampling_temperature_input = InputParam(
|
||||
name="sampling_temperature",
|
||||
type_hint=float,
|
||||
required=False,
|
||||
description="Sampling temperature",
|
||||
default=0.2,
|
||||
)
|
||||
sampling_max_tokens_input = InputParam(
|
||||
name="sampling_max_tokens", type_hint=int, required=False, description="Sampling max tokens", default=4096
|
||||
)
|
||||
return [
|
||||
prompt_input,
|
||||
image_input,
|
||||
json_prompt_input,
|
||||
sampling_top_p_input,
|
||||
sampling_temperature_input,
|
||||
sampling_max_tokens_input,
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"json_prompt",
|
||||
type_hint=str,
|
||||
description="JSON prompt by the VLM",
|
||||
)
|
||||
]
|
||||
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
prompt = block_state.prompt
|
||||
image = block_state.image
|
||||
json_prompt = block_state.json_prompt
|
||||
block_state.json_prompt = generate_json_prompt(
|
||||
vlm_processor=self.engine,
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
structured_prompt=json_prompt,
|
||||
top_p=block_state.sampling_top_p,
|
||||
temperature=block_state.sampling_temperature,
|
||||
max_tokens=block_state.sampling_max_tokens,
|
||||
stop=["<|im_end|>", "<|end_of_text|>"],
|
||||
)
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
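# Illustrative note (not part of the diff): as a ModularPipelineBlocks subclass this block is meant to be
# composed into a Bria FIBO modular pipeline. A hypothetical standalone instantiation (placeholder model
# id; a CUDA device is required because __init__ moves the engine to "cuda"):

block = BriaFiboVLMPromptToJson("<vlm-model-id>")
# The block reads `prompt` / `image` / `json_prompt` plus the sampling_* inputs from the pipeline state
# and writes back `json_prompt`, the structured caption used downstream for conditioning.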
|
||||
@@ -1,804 +0,0 @@
|
||||
import io
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from functools import cache
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from boltons.iterutils import remap
|
||||
from google import genai
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ...modular_pipelines import InputParam, ModularPipelineBlocks, OutputParam, PipelineState
|
||||
|
||||
|
||||
class ObjectDescription(BaseModel):
|
||||
description: str = Field(..., description="Short description of the object.")
|
||||
location: str = Field(..., description="E.g., 'center', 'top-left', 'bottom-right foreground'.")
|
||||
relationship: str = Field(
|
||||
..., description="Describe the relationship between the object and the other objects in the image."
|
||||
)
|
||||
relative_size: Optional[str] = Field(None, description="E.g., 'small', 'medium', 'large within frame'.")
|
||||
shape_and_color: Optional[str] = Field(None, description="Describe the basic shape and dominant color.")
|
||||
texture: Optional[str] = Field(None, description="E.g., 'smooth', 'rough', 'metallic', 'furry'.")
|
||||
appearance_details: Optional[str] = Field(None, description="Any other notable visual details.")
|
||||
# If cluster of object
|
||||
number_of_objects: Optional[int] = Field(None, description="The number of objects in the cluster.")
|
||||
# Human-specific fields
|
||||
pose: Optional[str] = Field(None, description="Describe the body position.")
|
||||
expression: Optional[str] = Field(None, description="Describe facial expression.")
|
||||
clothing: Optional[str] = Field(None, description="Describe attire.")
|
||||
action: Optional[str] = Field(None, description="Describe the action of the human.")
|
||||
gender: Optional[str] = Field(None, description="Describe the gender of the human.")
|
||||
skin_tone_and_texture: Optional[str] = Field(None, description="Describe the skin tone and texture.")
|
||||
orientation: Optional[str] = Field(None, description="Describe the orientation of the human.")
|
||||
|
||||
|
||||
class LightingDetails(BaseModel):
|
||||
conditions: str = Field(
|
||||
..., description="E.g., 'bright daylight', 'dim indoor', 'studio lighting', 'golden hour'."
|
||||
)
|
||||
direction: str = Field(..., description="E.g., 'front-lit', 'backlit', 'side-lit from left'.")
|
||||
shadows: Optional[str] = Field(None, description="Describe the presence of shadows.")
|
||||
|
||||
|
||||
class AestheticsDetails(BaseModel):
|
||||
composition: str = Field(..., description="E.g., 'rule of thirds', 'symmetrical', 'centered', 'leading lines'.")
|
||||
color_scheme: str = Field(
|
||||
..., description="E.g., 'monochromatic blue', 'warm complementary colors', 'high contrast'."
|
||||
)
|
||||
mood_atmosphere: str = Field(..., description="E.g., 'serene', 'energetic', 'mysterious', 'joyful'.")
|
||||
|
||||
|
||||
class PhotographicCharacteristicsDetails(BaseModel):
|
||||
depth_of_field: str = Field(..., description="E.g., 'shallow', 'deep', 'bokeh background'.")
|
||||
focus: str = Field(..., description="E.g., 'sharp focus on subject', 'soft focus', 'motion blur'.")
|
||||
camera_angle: str = Field(..., description="E.g., 'eye-level', 'low angle', 'high angle', 'dutch angle'.")
|
||||
lens_focal_length: str = Field(..., description="E.g., 'wide-angle', 'telephoto', 'macro', 'fisheye'.")
|
||||
|
||||
|
||||
class TextRender(BaseModel):
|
||||
text: str = Field(..., description="The text content.")
|
||||
location: str = Field(..., description="E.g., 'center', 'top-left', 'bottom-right foreground'.")
|
||||
size: str = Field(..., description="E.g., 'small', 'medium', 'large within frame'.")
|
||||
color: str = Field(..., description="E.g., 'red', 'blue', 'green'.")
|
||||
font: str = Field(..., description="E.g., 'realistic', 'cartoonish', 'minimalist'.")
|
||||
appearance_details: Optional[str] = Field(None, description="Any other notable visual details.")
|
||||
|
||||
|
||||
class ImageAnalysis(BaseModel):
|
||||
short_description: str = Field(..., description="A concise summary of the image content, 200 words maximum.")
|
||||
objects: List[ObjectDescription] = Field(..., description="List of prominent foreground/midground objects.")
|
||||
background_setting: str = Field(
|
||||
...,
|
||||
description="Describe the overall environment, setting, or background, including any notable background elements.",
|
||||
)
|
||||
lighting: LightingDetails = Field(..., description="Details about the lighting.")
|
||||
aesthetics: AestheticsDetails = Field(..., description="Details about the image aesthetics.")
|
||||
photographic_characteristics: Optional[PhotographicCharacteristicsDetails] = Field(
|
||||
None, description="Details about photographic characteristics."
|
||||
)
|
||||
style_medium: Optional[str] = Field(None, description="Identify the artistic style or medium.")
|
||||
text_render: Optional[List[TextRender]] = Field(None, description="List of text renders in the image.")
|
||||
context: str = Field(..., description="Provide any additional context that helps understand the image better.")
|
||||
artistic_style: Optional[str] = Field(
|
||||
None, description="describe specific artistic characteristics, 3 words maximum."
|
||||
)
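# Illustrative note (not part of the diff): these Pydantic models mirror the JSON schema spelled out in
# the prompts below. A minimal valid instance (pydantic v2 API assumed):

example = ImageAnalysis(
    short_description="A red bicycle leaning against a brick wall.",
    objects=[
        ObjectDescription(
            description="A vintage red bicycle.",
            location="center",
            relationship="leaning against the wall behind it",
        )
    ],
    background_setting="A weathered brick wall in a quiet alley.",
    lighting=LightingDetails(conditions="overcast daylight", direction="front-lit"),
    aesthetics=AestheticsDetails(
        composition="centered", color_scheme="warm reds against muted browns", mood_atmosphere="calm"
    ),
    context="A lifestyle photograph for a bicycle shop catalogue.",
)
print(example.model_dump_json(exclude_none=True))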
|
||||
|
||||
|
||||
def get_gemini_output_schema() -> dict:
|
||||
return {
|
||||
"properties": {
|
||||
"short_description": {"type": "STRING"},
|
||||
"objects": {
|
||||
"items": {
|
||||
"properties": {
|
||||
"description": {"type": "STRING"},
|
||||
"location": {"type": "STRING"},
|
||||
"relationship": {"type": "STRING"},
|
||||
"relative_size": {"type": "STRING"},
|
||||
"shape_and_color": {"type": "STRING"},
|
||||
"texture": {"nullable": True, "type": "STRING"},
|
||||
"appearance_details": {"nullable": True, "type": "STRING"},
|
||||
"number_of_objects": {"nullable": True, "type": "INTEGER"},
|
||||
"pose": {"nullable": True, "type": "STRING"},
|
||||
"expression": {"nullable": True, "type": "STRING"},
|
||||
"clothing": {"nullable": True, "type": "STRING"},
|
||||
"action": {"nullable": True, "type": "STRING"},
|
||||
"gender": {"nullable": True, "type": "STRING"},
|
||||
"skin_tone_and_texture": {"nullable": True, "type": "STRING"},
|
||||
"orientation": {"nullable": True, "type": "STRING"},
|
||||
},
|
||||
"required": [
|
||||
"description",
|
||||
"location",
|
||||
"relationship",
|
||||
"relative_size",
|
||||
"shape_and_color",
|
||||
"texture",
|
||||
"appearance_details",
|
||||
"number_of_objects",
|
||||
"pose",
|
||||
"expression",
|
||||
"clothing",
|
||||
"action",
|
||||
"gender",
|
||||
"skin_tone_and_texture",
|
||||
"orientation",
|
||||
],
|
||||
"type": "OBJECT",
|
||||
},
|
||||
"type": "ARRAY",
|
||||
},
|
||||
"background_setting": {"type": "STRING"},
|
||||
"lighting": {
|
||||
"properties": {
|
||||
"conditions": {"type": "STRING"},
|
||||
"direction": {"type": "STRING"},
|
||||
"shadows": {"nullable": True, "type": "STRING"},
|
||||
},
|
||||
"required": ["conditions", "direction", "shadows"],
|
||||
"type": "OBJECT",
|
||||
},
|
||||
"aesthetics": {
|
||||
"properties": {
|
||||
"composition": {"type": "STRING"},
|
||||
"color_scheme": {"type": "STRING"},
|
||||
"mood_atmosphere": {"type": "STRING"},
|
||||
},
|
||||
"required": ["composition", "color_scheme", "mood_atmosphere"],
|
||||
"type": "OBJECT",
|
||||
},
|
||||
"photographic_characteristics": {
|
||||
"nullable": True,
|
||||
"properties": {
|
||||
"depth_of_field": {"type": "STRING"},
|
||||
"focus": {"type": "STRING"},
|
||||
"camera_angle": {"type": "STRING"},
|
||||
"lens_focal_length": {"type": "STRING"},
|
||||
},
|
||||
"required": [
|
||||
"depth_of_field",
|
||||
"focus",
|
||||
"camera_angle",
|
||||
"lens_focal_length",
|
||||
],
|
||||
"type": "OBJECT",
|
||||
},
|
||||
"style_medium": {"type": "STRING"},
|
||||
"text_render": {
|
||||
"items": {
|
||||
"properties": {
|
||||
"text": {"type": "STRING"},
|
||||
"location": {"type": "STRING"},
|
||||
"size": {"type": "STRING"},
|
||||
"color": {"type": "STRING"},
|
||||
"font": {"type": "STRING"},
|
||||
"appearance_details": {"nullable": True, "type": "STRING"},
|
||||
},
|
||||
"required": [
|
||||
"text",
|
||||
"location",
|
||||
"size",
|
||||
"color",
|
||||
"font",
|
||||
"appearance_details",
|
||||
],
|
||||
"type": "OBJECT",
|
||||
},
|
||||
"type": "ARRAY",
|
||||
},
|
||||
"context": {"type": "STRING"},
|
||||
"artistic_style": {"type": "STRING"},
|
||||
},
|
||||
"required": [
|
||||
"short_description",
|
||||
"objects",
|
||||
"background_setting",
|
||||
"lighting",
|
||||
"aesthetics",
|
||||
"photographic_characteristics",
|
||||
"style_medium",
|
||||
"text_render",
|
||||
"context",
|
||||
"artistic_style",
|
||||
],
|
||||
"type": "OBJECT",
|
||||
}
|
||||
|
||||
|
||||
json_schema_full = """1. `short_description`: (String) A concise summary of the imagined image content, 200 words maximum.
|
||||
2. `objects`: (Array of Objects) List a maximum of 5 prominent objects. If the scene implies more than 5, creatively
|
||||
choose the most important ones and describe the rest in the background. For each object, include:
|
||||
* `description`: (String) A detailed description of the imagined object, 100 words maximum.
|
||||
* `location`: (String) E.g., "center", "top-left", "bottom-right foreground".
|
||||
* `relative_size`: (String) E.g., "small", "medium", "large within frame". (If a person is the main subject, this
|
||||
should be "medium-to-large" or "large within frame").
|
||||
* `shape_and_color`: (String) Describe the basic shape and dominant color.
|
||||
* `texture`: (String) E.g., "smooth", "rough", "metallic", "furry".
|
||||
* `appearance_details`: (String) Any other notable visual details.
|
||||
* `relationship`: (String) Describe the relationship between the object and the other objects in the image.
|
||||
* `orientation`: (String) Describe the orientation or positioning of the object, e.g., "upright", "tilted 45
|
||||
degrees", "horizontal", "vertical", "facing left", "facing right", "upside down", "lying on its side".
|
||||
* If the object is a human or a human-like object, include the following:
|
||||
* `pose`: (String) Describe the body position.
|
||||
* `expression`: (String) Describe facial expression and emotion. E.g., "winking", "joyful", "serious",
|
||||
"surprised", "calm".
|
||||
* `clothing`: (String) Describe attire.
|
||||
* `action`: (String) Describe the action of the human.
|
||||
* `gender`: (String) Describe the gender of the human.
|
||||
* `skin_tone_and_texture`: (String) Describe the skin tone and texture.
|
||||
* If the object is a cluster of objects, include the following:
|
||||
* `number_of_objects`: (Integer) The number of objects in the cluster.
|
||||
3. `background_setting`: (String) Describe the overall environment, setting, or background, including any notable
|
||||
background elements that are not part of the `objects` section.
|
||||
4. `lighting`: (Object)
|
||||
* `conditions`: (String) E.g., "bright daylight", "dim indoor", "studio lighting", "golden hour".
|
||||
* `direction`: (String) E.g., "front-lit", "backlit", "side-lit from left".
|
||||
* `shadows`: (String) Describe the presence and quality of shadows, e.g., "long, soft shadows", "sharp, defined
|
||||
shadows", "minimal shadows".
|
||||
5. `aesthetics`: (Object)
|
||||
* `composition`: (String) E.g., "rule of thirds", "symmetrical", "centered", "leading lines". If people are the
|
||||
main subject, specify the shot type, e.g., "medium shot", "close-up", "portrait composition".
|
||||
* `color_scheme`: (String) E.g., "monochromatic blue", "warm complementary colors", "high contrast".
|
||||
* `mood_atmosphere`: (String) E.g., "serene", "energetic", "mysterious", "joyful".
|
||||
6. `photographic_characteristics`: (Object)
|
||||
* `depth_of_field`: (String) E.g., "shallow", "deep", "bokeh background".
|
||||
* `focus`: (String) E.g., "sharp focus on subject", "soft focus", "motion blur".
|
||||
* `camera_angle`: (String) E.g., "eye-level", "low angle", "high angle", "dutch angle".
|
||||
* `lens_focal_length`: (String) E.g., "wide-angle", "telephoto", "macro", "fisheye". (If the main subject is a
|
||||
person, prefer "standard lens (e.g., 35mm-50mm)" or "portrait lens (e.g., 50mm-85mm)" to ensure they are framed
|
||||
more closely. Avoid "wide-angle" for people unless specified).
|
||||
7. `style_medium`: (String) Identify the artistic style or medium based on the user's prompt or creative
|
||||
interpretation (e.g., "photograph", "oil painting", "watercolor", "3D render", "digital illustration", "pencil
|
||||
sketch").
|
||||
8. `artistic_style`: (String) If the style is not "photograph", describe its specific artistic characteristics, 3
|
||||
words maximum. (e.g., "impressionistic, vibrant, textured" for an oil painting).
|
||||
9. `context`: (String) Provide a general description of the type of image this would be. For example: "This is a
|
||||
concept for a high-fashion editorial photograph intended for a magazine spread," or "This describes a piece of
|
||||
concept art for a fantasy video game."
|
||||
10. `text_render`: (Array of Objects) By default, this array should be empty (`[]`). Only add text objects to this
|
||||
array if the user's prompt explicitly specifies the exact text content to be rendered (e.g., user asks for "a
|
||||
poster with the title 'Cosmic Dream'"). Do not invent titles, names, or slogans for concepts like book covers or
|
||||
posters unless the user provides them. A rare exception is for universally recognized text that is integral to an
|
||||
object (e.g., the word 'STOP' on a 'stop sign'). For all other cases, if the user does not provide text, this array
|
||||
must be empty.
|
||||
* `text`: (String) The exact text content provided by the user. NEVER use generic placeholders.
|
||||
* `location`: (String) E.g., "center", "top-left", "bottom-right foreground".
|
||||
* `size`: (String) E.g., "medium", "large", "large within frame".
|
||||
* `color`: (String) E.g., "red", "blue", "green".
|
||||
* `font`: (String) E.g., "realistic", "cartoonish", "minimalist", "serif typeface".
|
||||
* `appearance_details`: (String) Any other notable visual details."""
|
||||
|
||||
|
||||
@cache
|
||||
def get_instructions(mode: str) -> Tuple[str, str]:
|
||||
system_prompts = {}
|
||||
|
||||
system_prompts["Caption"] = """
|
||||
You are a meticulous and perceptive Visual Art Director working for a leading Generative AI company. Your expertise
|
||||
lies in analyzing images and extracting detailed, structured information. Your primary task is to analyze provided
|
||||
images and generate a comprehensive JSON object describing them. Adhere strictly to the following structure and
|
||||
guidelines: The output MUST be ONLY a valid JSON object. Do not include any text before or after the JSON object (e.g.,
|
||||
no "Here is the JSON:", no explanations, no apologies). IMPORTANT: When describing human body parts, positions, or
|
||||
actions, always describe them from the PERSON'S OWN PERSPECTIVE, not from the observer's viewpoint. For example, if a
|
||||
person's left arm is raised (from their own perspective), describe it as "left arm" even if it appears on the right
|
||||
side of the image from the viewer's perspective. The JSON object must contain the following keys precisely:
|
||||
1. `short_description`: (String) A concise summary of the image content, 200 words maximum.
|
||||
2. `objects`: (Array of Objects) List a maximum of 5 prominent objects; if there are more than 5, list the rest in the
background. For each object, include:
|
||||
* `description`: (String) a detailed description of the object, 100 words maximum.
|
||||
* `location`: (String) E.g., "center", "top-left", "bottom-right foreground".
|
||||
* `relative_size`: (String) E.g., "small", "medium", "large within frame".
|
||||
* `shape_and_color`: (String) Describe the basic shape and dominant color.
|
||||
* `texture`: (String) E.g., "smooth", "rough", "metallic", "furry".
|
||||
* `appearance_details`: (String) Any other notable visual details.
|
||||
* `relationship`: (String) Describe the relationship between the object and the other objects in the image.
|
||||
* `orientation`: (String) Describe the orientation or positioning of the object, e.g., "upright", "tilted 45
|
||||
degrees", "horizontal", "vertical", "facing left", "facing right", "upside down", "lying on its side".
|
||||
if the object is a human or a human-like object, include the following:
|
||||
* `pose`: (String) Describe the body position.
|
||||
* `expression`: (String) Describe facial expression and emotion. E.g., "winking", "joyful", "serious",
|
||||
"surprised", "calm".
|
||||
* `clothing`: (String) Describe attire.
|
||||
* `action`: (String) Describe the action of the human.
|
||||
* `gender`: (String) Describe the gender of the human.
|
||||
* `skin_tone_and_texture`: (String) Describe the skin tone and texture.
|
||||
if the object is a cluster of objects, include the following:
|
||||
* `number_of_objects`: (Integer) The number of objects in the cluster.
|
||||
3. `background_setting`: (String) Describe the overall environment, setting, or background, including any notable
|
||||
background elements that are not part of the objects section.
|
||||
4. `lighting`: (Object)
|
||||
* `conditions`: (String) E.g., "bright daylight", "dim indoor", "studio lighting", "golden hour".
|
||||
* `direction`: (String) E.g., "front-lit", "backlit", "side-lit from left".
|
||||
* `shadows`: (String) Describe the presence of shadows.
|
||||
5. `aesthetics`: (Object)
|
||||
* `composition`: (String) E.g., "rule of thirds", "symmetrical", "centered", "leading lines".
|
||||
* `color_scheme`: (String) E.g., "monochromatic blue", "warm complementary colors", "high contrast".
|
||||
* `mood_atmosphere`: (String) E.g., "serene", "energetic", "mysterious", "joyful".
|
||||
6. `photographic_characteristics`: (Object)
|
||||
* `depth_of_field`: (String) E.g., "shallow", "deep", "bokeh background".
|
||||
* `focus`: (String) E.g., "sharp focus on subject", "soft focus", "motion blur".
|
||||
* `camera_angle`: (String) E.g., "eye-level", "low angle", "high angle", "dutch angle".
|
||||
* `lens_focal_length`: (String) E.g., "wide-angle", "telephoto", "macro", "fisheye".
|
||||
7. `style_medium`: (String) Identify the artistic style or medium (e.g., "photograph", "oil painting", "watercolor",
|
||||
"3D render", "digital illustration", "pencil sketch") If the style is not "photograph", but artistic, please
|
||||
describe the specific artistic characteristics under 'artistic_style', 50 words maximum.
|
||||
8. `artistic_style`: (String) describe specific artistic characteristics, 3 words maximum.
|
||||
9. `context`: (String) Provide any additional context that helps understand the image better. This should include a
|
||||
general description of the type of image (e.g., Fashion Photography, Product Shot, Magazine Cover, Nature
|
||||
Photography, Art Piece, etc.), as well as any other relevant contextual information that situates the image within a
|
||||
broader category or intended use. For example: "This is a high-fashion editorial photograph intended for a magazine
|
||||
spread"
|
||||
10. `text_render`: (Array of Objects) List of a maximum of 5 most prominent text renders in the image. For each text
|
||||
render, include:
|
||||
* `text`: (String) The text content.
|
||||
* `location`: (String) E.g., "center", "top-left", "bottom-right foreground".
|
||||
* `size`: (String) E.g., "small", "medium", "large within frame".
|
||||
* `color`: (String) E.g., "red", "blue", "green".
|
||||
* `font`: (String) E.g., "realistic", "cartoonish", "minimalist".
|
||||
* `appearance_details`: (String) Any other notable visual details.
Ensure the information within the JSON is accurate, detailed where specified, and avoids redundancy between fields.
"""

    system_prompts[
        "Generate"
    ] = f"""You are a visionary and creative Visual Art Director at a leading Generative AI company.

Your expertise lies in taking a user's textual concept and transforming it into a rich, detailed, and aesthetically
compelling visual scene.

Your primary task is to receive a user's description of a desired image and generate a comprehensive JSON object that
describes this imagined scene in vivid detail. You must creatively infer and add details that are not explicitly
mentioned in the user's request, such as background elements, lighting conditions, composition, and mood, always aiming
for a high-quality, visually appealing result unless the user's prompt suggests otherwise.

Adhere strictly to the following structure and guidelines:

The output MUST be ONLY a valid JSON object. Do not include any text before or after the JSON object (e.g., no "Here is
the JSON:", no explanations, no apologies).

IMPORTANT: When describing human body parts, positions, or actions, always describe them from the PERSON'S OWN
PERSPECTIVE, not from the observer's viewpoint. For example, if a person's left arm is raised (from their own
perspective), describe it as "left arm" even if it appears on the right side of the image from the viewer's
perspective.

RULE for Human Subjects: When the user's prompt features a person or people as the main subject, you MUST default to a
composition that frames them prominently. Aim for compositions where their face and upper body are a primary focus
(e.g., 'medium shot', 'close-up'). Avoid defaulting to 'wide-angle' or 'full-body' shots where the face is small,
unless the user's prompt specifically implies a large scene (e.g., "a person standing on a mountain").

Unless the user's prompt explicitly requests a different style (e.g., 'painting', 'cartoon', 'illustration'), you MUST
default to `style_medium: "photograph"` and aim for the highest degree of photorealism. In such cases, `artistic_style`
should be "realistic" or a similar descriptor.

The JSON object must contain the following keys precisely:

{json_schema_full}

Ensure the information within the JSON is detailed, creative, internally consistent, and avoids redundancy between
fields."""

    system_prompts[
        "RefineA"
    ] = f"""You are a Meticulous Visual Editor and Senior Art Director at a leading Generative AI company.

Your expertise is in refining and modifying existing visual concepts based on precise feedback.

Your primary task is to receive an existing JSON object that describes a visual scene, along with a textual instruction
for how to change it. You must then generate a new, updated JSON object that perfectly incorporates the requested
changes.

Adhere strictly to the following structure and guidelines:

1. **Input:** You will receive two pieces of information: an existing JSON object and a textual instruction.
2. **Output:** Your output MUST be ONLY a single, valid JSON object in the specified schema. Do not include any text
   before or after the JSON object.
3. **Modification Logic:**
   * Carefully parse the user's textual instruction to understand the desired changes.
   * Modify ONLY the fields in the JSON that are directly or logically affected by the instruction.
   * All other fields not relevant to the change must be copied exactly from the original JSON. Do not alter or omit
     them.
4. **Holistic Consistency (IMPORTANT):** Changes in one field must be logically reflected in others. For example:
   * If the instruction is to "change the background to a snowy forest," you must update the `background_setting`
     field, and also update the `short_description` to mention the new setting. The `mood_atmosphere` might also need
     to change to "serene" or "wintry."
   * If the instruction is to "add the text 'WINTER SALE' at the top," you must add a new entry to the `text_render`
     array.
   * If the instruction is to "make the person smile," you must update the `expression` field for that object and
     potentially update the overall `mood_atmosphere`.
5. **Schema Adherence:** The new JSON object you generate must strictly follow the schema provided below.

The JSON object must contain the following keys precisely:

{json_schema_full}"""

    system_prompts[
        "RefineB"
    ] = f"""You are an advanced Multimodal Visual Specialist at a leading Generative AI company.

Your unique expertise is in analyzing and editing visual concepts by processing an image, its corresponding JSON
metadata, and textual feedback simultaneously.

Your primary task is to receive three inputs: an existing image, its descriptive JSON object, and a textual instruction
for a modification. You must use the image as the primary source of truth to understand the context of the requested
change and then generate a new, updated JSON object that accurately reflects that change.

Adhere strictly to the following structure and guidelines:

1. **Inputs:** You will receive an image, an existing JSON object, and a textual instruction.
2. **Visual Grounding (IMPORTANT):** The provided image is the ground truth. Use it to visually verify the contents of
   the scene and to understand the context of the user's edit instruction. For example, if the instruction is "make
   the car blue," visually locate the car in the image to inform your edits to the JSON.
3. **Output:** Your output MUST be ONLY a single, valid JSON object in the specified schema. Do not include any text
   before or after the JSON object.
4. **Modification Logic:**
   * Analyze the user's textual instruction in the context of what you see in the image.
   * Modify ONLY the fields in the JSON that are directly or logically affected by the instruction.
   * All other fields not relevant to the change must be copied exactly from the original JSON.
5. **Holistic Consistency:** Changes must be reflected logically across the JSON, consistent with a potential visual
   change to the image. For instance, changing the lighting from 'daylight' to 'golden hour' should not only update
   the `lighting` object but also the `mood_atmosphere`, `shadows`, and the `short_description`.
6. **Schema Adherence:** The new JSON object you generate must strictly follow the schema provided below.

The JSON object must contain the following keys precisely:

{json_schema_full}"""

    system_prompts[
        "InspireA"
    ] = f"""You are a highly skilled Creative Director for Visual Adaptation at a leading Generative AI company.

Your expertise lies in using an existing image as a visual reference to create entirely new scenes. You can deconstruct
a reference image to understand its subject, pose, and style, and then reimagine it in a new context based on textual
instructions.

Your primary task is to receive a reference image and a textual instruction. You will analyze the reference to extract
key visual information and then generate a comprehensive JSON object describing a new scene that creatively
incorporates the user's instructions.

Adhere strictly to the following structure and guidelines:

1. **Inputs:** You will receive a reference image and a textual instruction. You will NOT receive a starting JSON.
2. **Core Logic (Analyze and Synthesize):**
   * **Analyze:** First, deeply analyze the provided reference image. Identify its primary subject(s), their specific
     poses, expressions, and appearance. Also note the overall composition, lighting style, and artistic medium.
   * **Synthesize:** Next, interpret the textual instruction to understand what elements to keep from the reference
     and what to change. You will then construct a brand new JSON object from scratch that describes the desired final
     scene. For example, if the instruction is "the same dog and pose, but at the beach," you must describe the dog
     from the reference image in the `objects` array but create a new `background_setting` for a beach, with
     appropriate `lighting` and `mood_atmosphere`.
3. **Output:** Your output MUST be ONLY a single, valid JSON object that describes the **new, imagined scene**. Do not
   describe the original reference image.
4. **Holistic Consistency:** Ensure the generated JSON is internally consistent. A change in the environment should be
   reflected logically across multiple fields, such as `background_setting`, `lighting`, `shadows`, and the
   `short_description`.
5. **Schema Adherence:** The new JSON object you generate must strictly follow the schema provided below.

The JSON object must contain the following keys precisely:

{json_schema_full}"""

    system_prompts["InspireB"] = system_prompts["Caption"]

    final_prompts = {}

    final_prompts["Generate"] = (
        "Generate a detailed JSON object, adhering to the expected schema, for an imagined scene based on the following request: {user_prompt}."
    )

    final_prompts["RefineA"] = """
[EXISTING JSON]: {json_data}

[EDIT INSTRUCTIONS]: {user_prompt}

[TASK]: Generate the new, updated JSON object that incorporates the edit instructions. Follow all system rules
for modification, consistency, and formatting.
"""

    final_prompts["RefineB"] = """
[EXISTING JSON]: {json_data}

[EDIT INSTRUCTIONS]: {user_prompt}

[TASK]: Analyze the provided image and its contextual JSON. Then, generate the new, updated JSON object that
incorporates the edit instructions. Follow all your system rules for visual analysis, modification, and
consistency.
"""

    final_prompts["InspireA"] = """
[EDIT INSTRUCTIONS]: {user_prompt}

[TASK]: Use the provided image as a visual reference only. Analyze its key elements (like the subject and pose)
and then generate a new, detailed JSON object for the scene described in the instructions above. Do not
describe the reference image itself; describe the new scene. Follow all of your system rules.
"""

    final_prompts["Caption"] = (
        "Analyze the provided image and generate the detailed JSON object as specified in your instructions."
    )
    final_prompts["InspireB"] = final_prompts["Caption"]

    return system_prompts.get(mode, ""), final_prompts.get(mode, "")
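

# Illustrative sketch (not from the original file): how the two templates returned
# above are meant to be combined. The mode name and the `user_prompt`/`json_data`
# placeholders come from the definitions above; the example values are made up.
def _example_get_instructions_usage():
    system_prompt, final_prompt = get_instructions("RefineA")
    # The final prompt is a plain template filled in with `str.format`, exactly as
    # `json_promptify` does further down in this file.
    filled = final_prompt.format(
        user_prompt="make the background a snowy forest",
        json_data='{"style_medium": "photograph"}',
    )
    return system_prompt, filled
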
def keep(p, k, v):
    is_none = v is None
    is_empty_string = isinstance(v, str) and v == ""
    is_empty_dict = isinstance(v, dict) and not v
    is_empty_list = isinstance(v, list) and not v
    is_nan = isinstance(v, float) and math.isnan(v)
    if is_none or is_empty_string or is_empty_list or is_empty_dict:
        return False
    if is_nan:
        return False
    return True
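

# Illustrative sketch (not from the original file): `keep` is used as a `remap`
# visitor to drop empty values. `remap` is assumed to be `boltons.iterutils.remap`,
# matching how `prepare_clean_caption` below uses it.
def _example_keep_filtering():
    nested = {"a": "", "b": {"c": None}, "d": 1, "e": [], "f": float("nan")}
    # None, empty strings, empty containers and NaNs are dropped; because remap
    # visits nodes bottom-up, a dict that becomes empty (like "b") is dropped too.
    return remap(nested, visit=keep)  # expected: {"d": 1}
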
def validate_json(json_data: str) -> dict:
    ia = ImageAnalysis.model_validate_json(json_data, strict=True)
    return ia.model_dump(exclude_none=True)


def validate_structured_prompt_str(structured_prompt_str: str) -> str:
    ia = ImageAnalysis.model_validate_json(structured_prompt_str, strict=True)
    c = ia.model_dump(exclude_none=True)
    return json.dumps(c)


def prepare_clean_caption(json_dump: dict) -> str:
    # Filter empty values recursively (i.e. None, "", {}, [], float("nan"))
    clean_caption_dict = remap(json_dump, visit=keep)

    scores = {"preference_score": "very high", "aesthetic_score": "very high"}
    # Set aesthetics scores
    if "aesthetics" not in clean_caption_dict:
        clean_caption_dict["aesthetics"] = scores
    else:
        clean_caption_dict["aesthetics"].update(scores)

    # Dump the clean structured caption as a minimal JSON string (no newline/whitespace separators)
    clean_caption_str = json.dumps(clean_caption_dict)
    return clean_caption_str
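

# Illustrative sketch (not from the original file): what `prepare_clean_caption`
# does to a partially-filled structured prompt. The field names are examples only.
def _example_prepare_clean_caption():
    raw = {
        "short_description": "a beautiful dog",
        "background_setting": "",  # dropped as an empty value
        "objects": [],  # dropped as an empty value
        "aesthetics": {"mood": "serene"},  # kept and extended with the score fields
    }
    # Returns a minified JSON string with "very high" preference/aesthetic scores added.
    return prepare_clean_caption(raw)
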
# Resize an input image to contain a specific number of pixels (1,048,576 = 1024x1024 by default)
# while maintaining a given aspect ratio and granularity (output width and height must be multiples of `granularity_val`).
def resize_image_by_num_pixels(
    image: Image.Image, pixel_number: int = 1048576, granularity_val: int = 64, target_ratio: float = 0.0
) -> Image.Image:
    if target_ratio != 0.0:
        ratio = target_ratio
    else:
        ratio = image.size[0] / image.size[1]
    width = int((pixel_number * ratio) ** 0.5)
    width = width - (width % granularity_val)
    height = int(pixel_number / width)
    height = height - (height % granularity_val)
    return image.resize((width, height))
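

# Illustrative sketch (not from the original file): resizing an arbitrary image to
# roughly one megapixel with width/height snapped to multiples of 64 (the default
# granularity). The input size here is made up.
def _example_resize_to_one_megapixel():
    image = Image.new("RGB", (1500, 900))
    resized = resize_image_by_num_pixels(image, pixel_number=1048576, granularity_val=64)
    # resized.size == (1280, 768): ~1 MP at the original 5:3 aspect ratio.
    return resized
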
def infer_with_gemini(
    client: genai.Client,
    final_prompt: str,
    system_prompt: str,
    top_p: float,
    temperature: float,
    max_tokens: int,
    seed: int = 42,
    image: Optional[Image.Image] = None,
    model: str = "gemini-2.5-flash",
) -> str:
    """
    Calls the Gemini API with the given prompt and returns the raw JSON response.

    Args:
        client: The Gemini client to use
        final_prompt: The text prompt to send to Gemini
        system_prompt: The system instruction for Gemini
        top_p: Top-p sampling value
        temperature: Sampling temperature
        max_tokens: Maximum number of output tokens
        seed: Random seed for sampling
        image: Optional image to include in the request
        model: The Gemini model to use

    Returns:
        Raw JSON response text from Gemini
    """
    parts = [{"text": final_prompt}]
    if image:
        # Save image into bytes
        image = image.convert("RGB")  # the model cannot produce RGBA, so sending alpha as input has no effect
        max_pixels = 262144  # roughly 512x512
        if image.size[0] * image.size[1] > max_pixels:
            image = resize_image_by_num_pixels(image, pixel_number=max_pixels, granularity_val=1, target_ratio=0.0)
        buffer = io.BytesIO()
        image.save(buffer, format="JPEG")
        image_bytes = buffer.getvalue()

        img_part = {
            "inlineData": {
                "data": image_bytes,
                "mimeType": "image/jpeg",
            }
        }
        parts.append(img_part)

    contents = [{"role": "user", "parts": parts}]

    generation_config = {
        "temperature": temperature,
        "topP": top_p,
        "maxOutputTokens": max_tokens,
        "response_mime_type": "application/json",
        "response_schema": get_gemini_output_schema(),
        "system_instruction": system_prompt,
        "thinkingConfig": {"thinkingBudget": 0},
        "seed": seed,
    }

    response = client.models.generate_content(
        model=model,
        contents=contents,
        config=generation_config,
    )

    if response.candidates[0].finish_reason == "MAX_TOKENS":
        raise RuntimeError("Gemini hit the max output token limit; the JSON response is truncated.")

    return response.candidates[0].content.parts[0].text
def get_default_negative_prompt(existing_json: dict) -> str:
    negative_prompt = ""
    style_medium = existing_json.get("style_medium", "").lower()
    if style_medium in ["photograph", "photography", "photo"]:
        negative_prompt = """{'style_medium': 'digital illustration', 'artistic_style': 'non-realistic'}"""
    return negative_prompt
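

# Illustrative sketch (not from the original file): the default negative prompt is
# only non-empty for photographic structured prompts.
def _example_default_negative_prompt():
    photo_json = {"style_medium": "photograph"}
    painting_json = {"style_medium": "oil painting"}
    return (
        get_default_negative_prompt(photo_json),  # non-empty negative prompt
        get_default_negative_prompt(painting_json),  # ""
    )
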
def json_promptify(
    client: genai.Client,
    model_id: str,
    top_p: float,
    temperature: float,
    max_tokens: int,
    user_prompt: Optional[str] = None,
    existing_json: Optional[str] = None,
    image: Optional[Image.Image] = None,
    seed: int = 42,
) -> str:
    if existing_json:
        # make sure aesthetic scores are not in the existing json (will be added later)
        existing_json = json.loads(existing_json)
        if "aesthetics" in existing_json:
            existing_json["aesthetics"].pop("aesthetic_score", None)
            existing_json["aesthetics"].pop("preference_score", None)
        existing_json = json.dumps(existing_json)

        if not user_prompt:
            raise ValueError("user_prompt is required if existing_json is provided")

        if image:
            mode = "RefineB"
            system_prompt, final_prompt = get_instructions(mode)
            final_prompt = final_prompt.format(user_prompt=user_prompt, json_data=existing_json)

        else:
            mode = "RefineA"
            system_prompt, final_prompt = get_instructions(mode)
            final_prompt = final_prompt.format(user_prompt=user_prompt, json_data=existing_json)
    elif image and user_prompt:
        mode = "InspireA"
        system_prompt, final_prompt = get_instructions(mode)
        final_prompt = final_prompt.format(user_prompt=user_prompt)
    elif image and not user_prompt:
        mode = "Caption"
        system_prompt, final_prompt = get_instructions(mode)
    else:
        mode = "Generate"
        system_prompt, final_prompt = get_instructions(mode)
        final_prompt = final_prompt.format(user_prompt=user_prompt)

    json_data = infer_with_gemini(
        client=client,
        model=model_id,
        final_prompt=final_prompt,
        system_prompt=system_prompt,
        seed=seed,
        image=image,
        top_p=top_p,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    json_data = validate_json(json_data)
    clean_caption = prepare_clean_caption(json_data)

    return clean_caption
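

# Illustrative sketch (not from the original file): driving `json_promptify` directly
# with a google-genai client. The sampling values mirror the defaults that
# `BriaFiboGeminiPromptToJson` below passes in; the prompt is made up.
def _example_json_promptify_generate():
    client = genai.Client()  # requires GOOGLE_API_KEY in the environment
    return json_promptify(
        client=client,
        model_id="gemini-2.5-flash",
        top_p=1.0,
        temperature=0.2,
        max_tokens=3000,
        user_prompt="a beautiful dog",
    )
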
class BriaFiboGeminiPromptToJson(ModularPipelineBlocks):
    model_name = "BriaFibo"

    def __init__(self, model_id="gemini-2.5-flash"):
        super().__init__()
        api_key = os.getenv("GOOGLE_API_KEY")
        if api_key is None:
            raise ValueError("Must provide an API key for Gemini through the `GOOGLE_API_KEY` env variable.")
        self.model_id = model_id

    @property
    def expected_components(self):
        return []

    @property
    def inputs(self) -> List[InputParam]:
        task_input = InputParam("task", type_hint=str, required=False, description="VLM Task to execute")
        prompt_input = InputParam(
            "prompt",
            type_hint=str,
            required=False,
            description="Prompt to use",
        )
        image_input = InputParam(
            name="image", type_hint=Image.Image, required=False, description="image for inspiration mode"
        )
        json_prompt_input = InputParam(
            name="json_prompt", type_hint=str, required=False, description="JSON prompt to use"
        )
        sampling_top_p_input = InputParam(
            name="sampling_top_p", type_hint=float, required=False, description="Sampling top p", default=1.0
        )
        sampling_temperature_input = InputParam(
            name="sampling_temperature",
            type_hint=float,
            required=False,
            description="Sampling temperature",
            default=0.2,
        )
        sampling_max_tokens_input = InputParam(
            name="sampling_max_tokens", type_hint=int, required=False, description="Sampling max tokens", default=3000
        )
        return [
            task_input,
            prompt_input,
            image_input,
            json_prompt_input,
            sampling_top_p_input,
            sampling_temperature_input,
            sampling_max_tokens_input,
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return []

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "json_prompt",
                type_hint=str,
                description="JSON prompt by the VLM",
            )
        ]

    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        prompt = block_state.prompt
        image = block_state.image
        json_prompt = block_state.json_prompt
        client = genai.Client()
        json_prompt = json_promptify(
            client=client,
            model_id=self.model_id,
            top_p=block_state.sampling_top_p,
            temperature=block_state.sampling_temperature,
            max_tokens=block_state.sampling_max_tokens,
            user_prompt=prompt,
            existing_json=json_prompt,
            image=image,
        )
        block_state.json_prompt = json_prompt
        self.set_block_state(state, block_state)

        return components, state
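

# Illustrative sketch (not from the original diff): once this block is published as
# remote code, it can be driven through `ModularPipeline` and re-invoked with an
# existing `json_prompt` to trigger the RefineA path above. The repo id matches the
# docstring example further down; the edit instruction is made up.
#
#     from diffusers.modular_pipelines import ModularPipeline
#
#     vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)
#     output = vlm_pipe(prompt="a beautiful dog")
#     json_prompt = output.values["json_prompt"]
#
#     refined = vlm_pipe(prompt="make the background a snowy forest", json_prompt=json_prompt)
#     refined_json_prompt = refined.values["json_prompt"]
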
@@ -1,3 +1,13 @@
# Copyright (c) Bria.ai. All rights reserved.
#
# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
#
# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
# indicate if changes were made, and do not use the material for commercial purposes.
#
# See the license for further details.

from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
@@ -35,20 +45,47 @@ else:
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    ...
    Example:
        ```python
        import torch
        from diffusers import BriaFiboPipeline
        from diffusers.modular_pipelines import ModularPipeline

        torch.set_grad_enabled(False)
        vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)

        pipe = BriaFiboPipeline.from_pretrained(
            "briaai/FIBO",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        )
        pipe.enable_sequential_cpu_offload()

        with torch.inference_mode():
            # 1. Create a prompt to generate an initial image
            output = vlm_pipe(prompt="a beautiful dog")
            json_prompt_generate = output.values["json_prompt"]

            # Generate the image from the structured json prompt
            results_generate = pipe(prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=5)
            results_generate.images[0].save("image_generate.png")
        ```
"""
class BriaFiboPipeline(DiffusionPipeline):
    r"""
    Args:
        transformer ([`GaiaTransformer2DModel`]):
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`SmolLM3ForCausalLM`]):
        transformer (`BriaFiboTransformer2DModel`):
            The transformer model for 2D diffusion modeling.
        scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`):
            Scheduler to be used with `transformer` to denoise the encoded latents.
        vae (`AutoencoderKLWan`):
            Variational Auto-Encoder for encoding and decoding images to and from latent representations.
        text_encoder (`SmolLM3ForCausalLM`):
            Text encoder for processing input prompts.
        tokenizer (`AutoTokenizer`):
            Tokenizer used for processing the input text prompts for the text_encoder.
    """

    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
@@ -166,7 +203,7 @@ class BriaFiboPipeline(DiffusionPipeline):
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        guidance_scale: float = 5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -181,8 +218,8 @@ class BriaFiboPipeline(DiffusionPipeline):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            guidance_scale (`float`):
                Guidance scale for classifier free guidance.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -224,7 +261,7 @@ class BriaFiboPipeline(DiffusionPipeline):
        prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
        prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers]

        if do_classifier_free_guidance:
        if guidance_scale > 1:
            if isinstance(negative_prompt, list) and negative_prompt[0] is None:
                negative_prompt = ""
            negative_prompt = negative_prompt or ""
@@ -302,9 +339,6 @@ class BriaFiboPipeline(DiffusionPipeline):
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def joint_attention_kwargs(self):
@@ -320,6 +354,7 @@ class BriaFiboPipeline(DiffusionPipeline):

    @staticmethod
    def _unpack_latents(latents, height, width, vae_scale_factor):
        # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline
        batch_size, num_patches, channels = latents.shape

        height = height // vae_scale_factor
@@ -366,6 +401,7 @@ class BriaFiboPipeline(DiffusionPipeline):

    @staticmethod
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
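

# Shape sketch (not part of the diff): `_pack_latents` turns a latent image into a
# sequence of 2x2 patches, e.g. for a toy tensor:
#
#     latents = torch.randn(1, 16, 8, 8)            # (batch, channels, height, width)
#     packed = BriaFiboPipeline._pack_latents(latents, 1, 16, 8, 8)
#     packed.shape                                  # torch.Size([1, 16, 64]) = (batch, (8//2)*(8//2), 16*4)
#
# `_unpack_latents` above reverses this, so latent height/width must be even.
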
@@ -410,34 +446,7 @@ class BriaFiboPipeline(DiffusionPipeline):
        return latents, latent_image_ids

    @staticmethod
    def init_inference_scheduler(height, width, device, image_seq_len, num_inference_steps=1000, noise_scheduler=None):
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)

        assert height % 16 == 0 and width % 16 == 0

        mu = calculate_shift(
            image_seq_len,
            noise_scheduler.config.base_image_seq_len,
            noise_scheduler.config.max_image_seq_len,
            noise_scheduler.config.base_shift,
            noise_scheduler.config.max_shift,
        )

        # Init sigmas and timesteps according to shift size
        # This changes the scheduler in-place according to the dynamic scheduling
        timesteps, num_inference_steps = retrieve_timesteps(
            noise_scheduler,
            num_inference_steps=num_inference_steps,
            device=device,
            timesteps=None,
            sigmas=sigmas,
            mu=mu,
        )

        return noise_scheduler, timesteps, num_inference_steps, mu

    @staticmethod
    def create_attention_matrix(attention_mask):
    def _prepare_attention_mask(attention_mask):
        attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask)

        # convert to 0 - keep, -inf ignore
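

# Sketch (not part of the diff): the einsum above builds a pairwise keep/keep matrix
# from a per-token mask, which is then mapped to an additive attention bias, e.g.:
#
#     mask = torch.tensor([[1.0, 1.0, 0.0]])        # batch of one, three tokens
#     torch.einsum("bi,bj->bij", mask, mask)
#     # tensor([[[1., 1., 0.],
#     #          [1., 1., 0.],
#     #          [0., 0., 0.]]])
#
# Entries equal to 1 stay attendable (bias 0.0); entries equal to 0 become -inf so
# padded tokens are ignored by every attention head.
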
@@ -583,7 +592,7 @@ class BriaFiboPipeline(DiffusionPipeline):
        ) = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            guidance_scale=guidance_scale,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            device=device,
@@ -593,7 +602,7 @@ class BriaFiboPipeline(DiffusionPipeline):
        )
        prompt_batch_size = prompt_embeds.shape[0]

        if self.do_classifier_free_guidance:
        if guidance_scale > 1:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            prompt_layers = [
                torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers))
@@ -611,6 +620,7 @@ class BriaFiboPipeline(DiffusionPipeline):
            prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers))

        # 5. Prepare latent variables

        num_channels_latents = self.transformer.config.in_channels
        if do_patching:
            num_channels_latents = int(num_channels_latents / 4)
@@ -630,11 +640,11 @@ class BriaFiboPipeline(DiffusionPipeline):
        latent_attention_mask = torch.ones(
            [latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device
        )
        if self.do_classifier_free_guidance:
        if guidance_scale > 1:
            latent_attention_mask = latent_attention_mask.repeat(2, 1)

        attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1)
        attention_mask = self.create_attention_matrix(attention_mask)  # batch, seq => batch, seq, seq
        attention_mask = self._prepare_attention_mask(attention_mask)  # batch, seq => batch, seq, seq
        attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype)  # for head broadcasting

        if self._joint_attention_kwargs is None:
@@ -648,13 +658,25 @@ class BriaFiboPipeline(DiffusionPipeline):
        else:
            seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor)

        self.noise_scheduler, timesteps, num_inference_steps, mu = self.init_inference_scheduler(
            height=height,
            width=width,
            device=device,
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)

        mu = calculate_shift(
            seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )

        # Init sigmas and timesteps according to shift size
        # This changes the scheduler in-place according to the dynamic scheduling
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps=num_inference_steps,
            noise_scheduler=self.scheduler,
            image_seq_len=seq_len,
            device=device,
            timesteps=None,
            sigmas=sigmas,
            mu=mu,
        )

        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
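

# Note (not part of the diff): in the Flux-style pipelines this helper is typically
# borrowed from, `calculate_shift` linearly interpolates the timestep-shift parameter
# `mu` between `base_shift` and `max_shift` as the image token count grows, roughly:
#
#     m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
#     mu = base_shift + m * (image_seq_len - base_image_seq_len)
#
# so larger images get a stronger shift before `retrieve_timesteps` rebuilds the
# sigma schedule in place.
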
@@ -674,10 +696,7 @@ class BriaFiboPipeline(DiffusionPipeline):
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

                if type(self.scheduler) != FlowMatchEulerDiscreteScheduler:
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0]).to(
@@ -697,7 +716,7 @@ class BriaFiboPipeline(DiffusionPipeline):
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                if guidance_scale > 1:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

@@ -736,7 +755,6 @@ class BriaFiboPipeline(DiffusionPipeline):
        else:
            latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor)

        latents = latents.to(dtype=self.vae.dtype)
        latents = latents.unsqueeze(dim=2)
        latents = list(torch.unbind(latents, dim=0))
        latents_device = latents[0].device
@@ -780,8 +798,8 @@ class BriaFiboPipeline(DiffusionPipeline):
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -818,9 +836,3 @@ class BriaFiboPipeline(DiffusionPipeline):

        if max_sequence_length is not None and max_sequence_length > 3000:
            raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}")

    def to(self, *args, **kwargs):
        DiffusionPipeline.to(self, *args, **kwargs)
        # Keep the VAE in float32, matching the Wan 2.2 reference implementation
        self.vae.to(dtype=torch.float32)
        return self
@@ -20,7 +20,7 @@ import torch
from diffusers import BriaFiboTransformer2DModel

from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
from ..test_modeling_common import ModelTesterMixin


enable_full_determinism()
@@ -84,48 +84,6 @@ class BriaFiboTransformerTests(ModelTesterMixin, unittest.TestCase):
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_deprecated_inputs_img_txt_ids_3d(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output_1 = model(**inputs_dict)[0]

        # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
        text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
        image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)

        assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
        assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"

        inputs_dict["txt_ids"] = text_ids_3d
        inputs_dict["img_ids"] = image_ids_3d

        with torch.no_grad():
            output_2 = model(**inputs_dict)[0]

        self.assertEqual(output_1.shape, output_2.shape)
        self.assertTrue(
            torch.allclose(output_1, output_2, atol=1e-5),
            msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
        )

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"BriaFiboTransformer2DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class BriaFiboTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    model_class = BriaFiboTransformer2DModel

    def prepare_init_args_and_inputs_for_common(self):
        return BriaFiboTransformerTests().prepare_init_args_and_inputs_for_common()


class BriaFiboTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
    model_class = BriaFiboTransformer2DModel

    def prepare_init_args_and_inputs_for_common(self):
        return BriaFiboTransformerTests().prepare_init_args_and_inputs_for_common()
@@ -128,7 +128,7 @@ class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        pipe = pipe.to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        height_width_pairs = [(32, 32)]
        height_width_pairs = [(32, 32), (64, 64), (32, 64)]
        for height, width in height_width_pairs:
            expected_height = height
            expected_width = width
@@ -181,13 +181,18 @@ class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
            max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
        )

    def test_to_dtype(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)
    # def test_to_dtype(self):
    #     components = self.get_dummy_components()
    #     pipe = self.pipeline_class(**components)
    #     pipe.set_progress_bar_config(disable=None)

        model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
        self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
    #     model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
    #     self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

    #     pipe.to(dtype=torch.float16)
    #     model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
    #     self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))

    @unittest.skip("")
    def test_save_load_dduf(self):
        pass