mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Bria FIBO pipeline
37
docs/source/en/api/pipelines/bria_fibo.md
Normal file
@@ -0,0 +1,37 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Bria Fibo

Text-to-image models have mastered imagination, but not control. FIBO changes that.

FIBO is trained on structured JSON captions of up to 1,000+ words and is designed to understand and control different visual parameters such as lighting, composition, color, and camera settings, enabling precise and reproducible outputs.

With only 8 billion parameters, FIBO provides a new level of image quality, prompt adherence, and professional control.

## Usage

_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO), fill in the form, and accept the gate. Once you are in, you need to log in so that your system knows you've accepted the gate._

Use the command below to log in:

```bash
hf auth login
```
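
After logging in, image generation follows the usual diffusers flow. The snippet below is only a minimal sketch: it assumes `BriaFiboPipeline` accepts a FIBO-style JSON prompt string through the standard `prompt` argument, and the caption text, dtype, and sampling settings are illustrative placeholders rather than recommended values. Structured JSON prompts can also be produced with the `BriaFiboVLMPromptToJson` / `BriaFiboGeminiPromptToJson` modular blocks added in this commit.

```py
import torch
from diffusers import BriaFiboPipeline

pipe = BriaFiboPipeline.from_pretrained("briaai/FIBO", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# FIBO is conditioned on structured JSON captions; this short caption is illustrative only.
prompt = '{"short_description": "A red bicycle leaning against a brick wall at golden hour"}'

image = pipe(prompt=prompt, num_inference_steps=30).images[0]
image.save("fibo.png")
```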

## BriaFiboPipeline

[[autodoc]] BriaFiboPipeline
- all
- __call__
@@ -199,6 +199,7 @@ else:
"AutoencoderTiny",
"AutoModel",
"BriaTransformer2DModel",
"BriaFiboTransformer2DModel",
"CacheMixin",
"ChromaTransformer2DModel",
"CogVideoXTransformer3DModel",
@@ -392,6 +393,8 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["modular_pipelines"].extend(
[
"BriaFiboVLMPromptToJson",
"BriaFiboGeminiPromptToJson",
"FluxAutoBlocks",
"FluxKontextAutoBlocks",
"FluxKontextModularPipeline",
@@ -431,6 +434,7 @@ else:
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"BriaPipeline",
"BriaFiboPipeline",
"ChromaImg2ImgPipeline",
"ChromaPipeline",
"CLIPImageProjection",
@@ -902,6 +906,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderTiny,
AutoModel,
BriaTransformer2DModel,
BriaFiboTransformer2DModel,
CacheMixin,
ChromaTransformer2DModel,
CogVideoXTransformer3DModel,
@@ -1104,6 +1109,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDMPipeline,
AuraFlowPipeline,
BriaPipeline,
BriaFiboPipeline,
ChromaImg2ImgPipeline,
ChromaPipeline,
CLIPImageProjection,

@@ -84,6 +84,7 @@ if is_torch_available():
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
_import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"]
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
@@ -175,6 +176,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AllegroTransformer3DModel,
AuraFlowTransformer2DModel,
BriaTransformer2DModel,
BriaFiboTransformer2DModel,
ChromaTransformer2DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
446
src/diffusers/models/transformers/transformer_bria_fibo.py
Normal file
@@ -0,0 +1,446 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...models.attention_processor import Attention
|
||||
from ...models.embeddings import TimestepEmbedding, get_timestep_embedding
|
||||
from ...models.modeling_outputs import Transformer2DModelOutput
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZeroSingle
|
||||
from ...models.transformers.transformer_bria import BriaAttnProcessor
|
||||
from ...models.transformers.transformer_flux import FluxTransformerBlock
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def get_1d_rotary_pos_embed(
|
||||
dim: int,
|
||||
pos: Union[np.ndarray, int],
|
||||
theta: float = 10000.0,
|
||||
use_real=False,
|
||||
linear_factor=1.0,
|
||||
ntk_factor=1.0,
|
||||
repeat_interleave_real=True,
|
||||
freqs_dtype=torch.float32, # torch.float32, torch.float64
|
||||
):
|
||||
"""
|
||||
Precompute the frequency tensor for complex exponentials (cis) with given dimensions. This function calculates a
|
||||
frequency tensor with complex exponentials using the given dimension 'dim' and the end index 'end'. The 'theta'
|
||||
parameter scales the frequencies. The returned tensor contains complex values in complex64 data type.
|
||||
|
||||
Args:
|
||||
dim (`int`): Dimension of the frequency tensor.
|
||||
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
|
||||
theta (`float`, *optional*, defaults to 10000.0):
|
||||
Scaling factor for frequency computation. Defaults to 10000.0.
|
||||
use_real (`bool`, *optional*):
|
||||
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
linear_factor (`float`, *optional*, defaults to 1.0):
|
||||
Scaling factor for the context extrapolation. Defaults to 1.0.
|
||||
ntk_factor (`float`, *optional*, defaults to 1.0):
|
||||
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
|
||||
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
|
||||
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
|
||||
Otherwise, they are concatenated with themselves.
|
||||
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
|
||||
the dtype of the frequency tensor.
|
||||
Returns:
|
||||
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
|
||||
"""
|
||||
assert dim % 2 == 0
|
||||
|
||||
if isinstance(pos, int):
|
||||
pos = torch.arange(pos)
|
||||
if isinstance(pos, np.ndarray):
|
||||
pos = torch.from_numpy(pos) # type: ignore # [S]
|
||||
|
||||
theta = theta * ntk_factor
|
||||
freqs = (
|
||||
1.0
|
||||
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
|
||||
/ linear_factor
|
||||
) # [D/2]
|
||||
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
|
||||
if use_real and repeat_interleave_real:
|
||||
# flux, hunyuan-dit, cogvideox
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
elif use_real:
|
||||
# stable audio, allegro
|
||||
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
|
||||
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
else:
|
||||
# lumina
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
|
||||
return freqs_cis
|
||||
|
||||
|
||||
class EmbedND(torch.nn.Module):
|
||||
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
|
||||
def __init__(self, theta: int, axes_dim: List[int]):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
cos_out = []
|
||||
sin_out = []
|
||||
pos = ids.float()
|
||||
is_mps = ids.device.type == "mps"
|
||||
freqs_dtype = torch.float32 if is_mps else torch.float64
|
||||
for i in range(n_axes):
|
||||
cos, sin = get_1d_rotary_pos_embed(
|
||||
self.axes_dim[i],
|
||||
pos[:, i],
|
||||
theta=self.theta,
|
||||
repeat_interleave_real=True,
|
||||
use_real=True,
|
||||
freqs_dtype=freqs_dtype,
|
||||
)
|
||||
cos_out.append(cos)
|
||||
sin_out.append(sin)
|
||||
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
|
||||
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
|
||||
return freqs_cos, freqs_sin
|
||||
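# Shape note (illustrative): for ids of shape [S, 3] and axes_dim = [16, 56, 56], each axis i
# contributes cos/sin tables of width axes_dim[i], so the concatenated rotary embeddings have
# shape [S, 16 + 56 + 56] = [S, 128], matching the default attention_head_dim used below.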
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class BriaFiboSingleTransformerBlock(nn.Module):
|
||||
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
|
||||
super().__init__()
|
||||
self.mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
|
||||
self.norm = AdaLayerNormZeroSingle(dim)
|
||||
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
|
||||
self.act_mlp = nn.GELU(approximate="tanh")
|
||||
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
||||
|
||||
processor = BriaAttnProcessor()
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
bias=True,
|
||||
processor=processor,
|
||||
qk_norm="rms_norm",
|
||||
eps=1e-6,
|
||||
pre_only=True,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
||||
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
||||
gate = gate.unsqueeze(1)
|
||||
hidden_states = gate * self.proj_out(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
if hidden_states.dtype == torch.float16:
|
||||
hidden_states = hidden_states.clip(-65504, 65504)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TextProjection(nn.Module):
|
||||
def __init__(self, in_features, hidden_size):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)
|
||||
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear(caption)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(
|
||||
self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
|
||||
):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
self.scale = scale
|
||||
self.time_theta = time_theta
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
scale=self.scale,
|
||||
max_period=self.time_theta,
|
||||
)
|
||||
return t_emb
|
||||
|
||||
|
||||
class TimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, time_theta):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(
|
||||
num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
|
||||
)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timestep, dtype):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D)
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
"""
|
||||
Parameters:
|
||||
patch_size (`int`): Patch size to turn the input data into small patches.
|
||||
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
|
||||
num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
|
||||
num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
|
||||
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
|
||||
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
|
||||
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
|
||||
...
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 1,
|
||||
in_channels: int = 64,
|
||||
num_layers: int = 19,
|
||||
num_single_layers: int = 38,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 4096,
|
||||
pooled_projection_dim: Optional[int] = None,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: List[int] = [16, 56, 56],
|
||||
rope_theta=10000,
|
||||
time_theta=10000,
|
||||
text_encoder_dim: int = 2048,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
||||
|
||||
self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
|
||||
|
||||
self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
|
||||
|
||||
if guidance_embeds:
|
||||
self.guidance_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
|
||||
|
||||
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
|
||||
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
FluxTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=self.config.num_attention_heads,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
)
|
||||
for i in range(self.config.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BriaFiboSingleTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=self.config.num_attention_heads,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
)
|
||||
for i in range(self.config.num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
caption_projection = [
|
||||
TextProjection(in_features=text_encoder_dim, hidden_size=self.inner_dim // 2)
|
||||
for i in range(self.config.num_layers + self.config.num_single_layers)
|
||||
]
|
||||
self.caption_projection = nn.ModuleList(caption_projection)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
text_encoder_layers: list = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||
from the embeddings of input conditions.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype)
|
||||
if guidance is not None:
|
||||
guidance = guidance.to(hidden_states.dtype)
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
temb = self.time_embed(timestep, dtype=hidden_states.dtype)
|
||||
|
||||
if guidance is not None:
|
||||
temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)
|
||||
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
if len(txt_ids.shape) == 3:
|
||||
txt_ids = txt_ids[0]
|
||||
|
||||
if len(img_ids.shape) == 3:
|
||||
img_ids = img_ids[0]
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
new_text_encoder_layers = []
|
||||
for i, text_encoder_layer in enumerate(text_encoder_layers):
|
||||
text_encoder_layer = self.caption_projection[i](text_encoder_layer)
|
||||
new_text_encoder_layers.append(text_encoder_layer)
|
||||
text_encoder_layers = new_text_encoder_layers
|
||||
|
||||
block_id = 0
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
current_text_encoder_layer = text_encoder_layers[block_id]
|
||||
encoder_hidden_states = torch.cat(
|
||||
[encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1
|
||||
)
|
||||
block_id += 1
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
joint_attention_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
current_text_encoder_layer = text_encoder_layers[block_id]
|
||||
encoder_hidden_states = torch.cat(
|
||||
[encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1
|
||||
)
|
||||
block_id += 1
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
joint_attention_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
|
||||
encoder_hidden_states = hidden_states[:, : encoder_hidden_states.shape[1], ...]
|
||||
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
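To sanity-check the transformer defined above, here is a minimal smoke-test sketch with a deliberately tiny, made-up configuration. Every value is illustrative (not the released checkpoint's settings), and it assumes `BriaFiboTransformer2DModel` is importable from the top-level `diffusers` namespace as the import-structure changes above suggest.

```py
import torch
from diffusers import BriaFiboTransformer2DModel

# Tiny illustrative config: inner_dim = 4 heads * 8 head_dim = 32, and the rotary axes
# [2, 4, 2] are even and sum to attention_head_dim.
model = BriaFiboTransformer2DModel(
    in_channels=4,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=8,
    num_attention_heads=4,
    joint_attention_dim=32,
    axes_dims_rope=[2, 4, 2],
    text_encoder_dim=16,
)

batch, img_seq, txt_seq = 1, 16, 8
latents = torch.randn(batch, img_seq, 4)                 # packed latent tokens
encoder_hidden_states = torch.randn(batch, txt_seq, 32)  # text encoder features
# One projected text layer per block: num_layers + num_single_layers = 2.
text_encoder_layers = [torch.randn(batch, txt_seq, 16) for _ in range(2)]

img_ids = torch.zeros(img_seq, 3)
img_ids[:, 1] = torch.arange(img_seq) // 4               # latent row index
img_ids[:, 2] = torch.arange(img_seq) % 4                # latent column index
txt_ids = torch.zeros(txt_seq, 3)

with torch.no_grad():
    sample = model(
        hidden_states=latents,
        encoder_hidden_states=encoder_hidden_states,
        text_encoder_layers=text_encoder_layers,
        timestep=torch.tensor([500.0]),
        img_ids=img_ids,
        txt_ids=txt_ids,
        return_dict=False,
    )[0]

print(sample.shape)  # torch.Size([1, 16, 4]): patch_size**2 * out_channels features per token
```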
@@ -45,6 +45,7 @@ else:
|
||||
"InsertableDict",
|
||||
]
|
||||
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
|
||||
_import_structure["bria_fibo"] = ["BriaFiboVLMPromptToJson", "BriaFiboGeminiPromptToJson"]
|
||||
_import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
|
||||
_import_structure["flux"] = [
|
||||
"FluxAutoBlocks",
|
||||
@@ -69,6 +70,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_pt_objects import * # noqa F403
|
||||
else:
|
||||
from .bria_fibo import BriaFiboGeminiPromptToJson, BriaFiboVLMPromptToJson
|
||||
from .components_manager import ComponentsManager
|
||||
from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
|
||||
from .modular_pipeline import (
|
||||
|
||||
47
src/diffusers/modular_pipelines/bria_fibo/__init__.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["gemini_prompt_to_json"] = ["BriaFiboGeminiPromptToJson"]
|
||||
_import_structure["fibo_vlm_prompt_to_json"] = ["BriaFiboVLMPromptToJson"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .gemini_prompt_to_json import BriaFiboGeminiPromptToJson
|
||||
from .fibo_vlm_prompt_to_json import BriaFiboVLMPromptToJson
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
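Given the lazy-module setup above and the top-level registrations earlier in this commit, both import paths below should resolve once `torch` and `transformers` are installed (a usage sketch, not additional API):

```py
from diffusers import BriaFiboGeminiPromptToJson
from diffusers.modular_pipelines.bria_fibo import BriaFiboVLMPromptToJson
```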
@@ -0,0 +1,377 @@
|
||||
import json
|
||||
import math
|
||||
import textwrap
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
import torch
|
||||
from boltons.iterutils import remap
|
||||
from PIL import Image
|
||||
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
|
||||
|
||||
from .. import ComponentSpec, InputParam, ModularPipelineBlocks, OutputParam, PipelineState
|
||||
|
||||
|
||||
def parse_aesthetic_score(record: dict) -> str:
|
||||
ae = record["aesthetic_score"]
|
||||
if ae < 5.5:
|
||||
return "very low"
|
||||
elif ae < 6:
|
||||
return "low"
|
||||
elif ae < 7:
|
||||
return "medium"
|
||||
elif ae < 7.6:
|
||||
return "high"
|
||||
else:
|
||||
return "very high"
|
||||
|
||||
|
||||
def parse_pickascore(record: dict) -> str:
|
||||
ps = record["pickascore"]
|
||||
if ps < 0.78:
|
||||
return "very low"
|
||||
elif ps < 0.82:
|
||||
return "low"
|
||||
elif ps < 0.87:
|
||||
return "medium"
|
||||
elif ps < 0.91:
|
||||
return "high"
|
||||
else:
|
||||
return "very high"
|
||||
|
||||
|
||||
def prepare_clean_caption(record: dict) -> str:
|
||||
def keep(p, k, v):
|
||||
is_none = v is None
|
||||
is_empty_string = isinstance(v, str) and v == ""
|
||||
is_empty_dict = isinstance(v, dict) and not v
|
||||
is_empty_list = isinstance(v, list) and not v
|
||||
is_nan = isinstance(v, float) and math.isnan(v)
|
||||
if is_none or is_empty_string or is_empty_list or is_empty_dict or is_nan:
|
||||
return False
|
||||
return True
|
||||
|
||||
try:
|
||||
scores = {}
|
||||
if "pickascore" in record:
|
||||
scores["preference_score"] = parse_pickascore(record)
|
||||
if "aesthetic_score" in record:
|
||||
scores["aesthetic_score"] = parse_aesthetic_score(record)
|
||||
|
||||
clean_caption_dict = remap(record, visit=keep)
|
||||
|
||||
# Set aesthetics scores
|
||||
if "aesthetics" not in clean_caption_dict:
|
||||
if len(scores) > 0:
|
||||
clean_caption_dict["aesthetics"] = scores
|
||||
else:
|
||||
clean_caption_dict["aesthetics"].update(scores)
|
||||
|
||||
# Dump the clean structured caption as a compact, single-line JSON string
|
||||
clean_caption_str = json.dumps(clean_caption_dict)
|
||||
return clean_caption_str
|
||||
except Exception as ex:
|
||||
print("Error: ", ex)
|
||||
raise ex
|
||||
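# Illustrative example of the cleaning step above (hypothetical record): empty and None fields
# are dropped, and the raw scores are bucketed into an "aesthetics" entry, e.g.
#   record = {"short_description": "A red bicycle", "objects": [], "style_medium": None,
#             "aesthetic_score": 7.8, "pickascore": 0.85}
#   prepare_clean_caption(record)
#   -> '{"short_description": "A red bicycle", "aesthetic_score": 7.8, "pickascore": 0.85,
#        "aesthetics": {"preference_score": "medium", "aesthetic_score": "very high"}}'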
|
||||
|
||||
def _collect_images(messages: Iterable[Dict[str, Any]]) -> List[Image.Image]:
|
||||
images: List[Image.Image] = []
|
||||
for message in messages:
|
||||
content = message.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if item.get("type") != "image":
|
||||
continue
|
||||
image_value = item.get("image")
|
||||
if isinstance(image_value, Image.Image):
|
||||
images.append(image_value)
|
||||
else:
|
||||
raise ValueError("Expected PIL.Image for image content in messages.")
|
||||
return images
|
||||
|
||||
|
||||
def _strip_stop_sequences(text: str, stop_sequences: Optional[List[str]]) -> str:
|
||||
if not stop_sequences:
|
||||
return text.strip()
|
||||
cleaned = text
|
||||
for stop in stop_sequences:
|
||||
if not stop:
|
||||
continue
|
||||
index = cleaned.find(stop)
|
||||
if index >= 0:
|
||||
cleaned = cleaned[:index]
|
||||
return cleaned.strip()
|
||||
|
||||
|
||||
class TransformersEngine(torch.nn.Module):
|
||||
"""Inference wrapper using Hugging Face transformers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
*,
|
||||
processor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
default_processor_kwargs: Dict[str, Any] = {
|
||||
"min_pixels": 256 * 28 * 28,
|
||||
"max_pixels": 1024 * 28 * 28,
|
||||
}
|
||||
processor_kwargs = {**default_processor_kwargs, **(processor_kwargs or {})}
|
||||
model_kwargs = model_kwargs or {}
|
||||
|
||||
self.processor = AutoProcessor.from_pretrained(model, **processor_kwargs)
|
||||
|
||||
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
|
||||
model,
|
||||
dtype=torch.bfloat16,
|
||||
**model_kwargs,
|
||||
)
|
||||
self.model.eval()
|
||||
|
||||
tokenizer_obj = self.processor.tokenizer
|
||||
if tokenizer_obj.pad_token_id is None:
|
||||
tokenizer_obj.pad_token = tokenizer_obj.eos_token
|
||||
self._pad_token_id = tokenizer_obj.pad_token_id
|
||||
eos_token_id = tokenizer_obj.eos_token_id
|
||||
if isinstance(eos_token_id, list) and eos_token_id:
|
||||
self._eos_token_id = eos_token_id
|
||||
elif eos_token_id is not None:
|
||||
self._eos_token_id = [eos_token_id]
|
||||
else:
|
||||
raise ValueError("Tokenizer must define an EOS token for generation.")
|
||||
|
||||
def dtype(self) -> torch.dtype:
|
||||
return self.model.dtype
|
||||
|
||||
def device(self) -> torch.device:
|
||||
return self.model.device
|
||||
|
||||
def _to_model_device(self, value: Any) -> Any:
|
||||
if not isinstance(value, torch.Tensor):
|
||||
return value
|
||||
target_device = getattr(self.model, "device", None)
|
||||
if target_device is None or target_device.type == "meta":
|
||||
return value
|
||||
if value.device == target_device:
|
||||
return value
|
||||
return value.to(target_device)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
top_p: float,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
stop: Optional[List[str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
tokenizer = self.processor.tokenizer
|
||||
prompt_text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
processor_inputs: Dict[str, Any] = {
|
||||
"text": [prompt_text],
|
||||
"padding": True,
|
||||
"return_tensors": "pt",
|
||||
}
|
||||
images = _collect_images(messages)
|
||||
if images:
|
||||
processor_inputs["images"] = images
|
||||
inputs = self.processor(**processor_inputs)
|
||||
inputs = {key: self._to_model_device(value) for key, value in inputs.items()}
|
||||
|
||||
generation_kwargs: Dict[str, Any] = {
|
||||
"max_new_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"do_sample": temperature > 0,
|
||||
"eos_token_id": self._eos_token_id,
|
||||
"pad_token_id": self._pad_token_id,
|
||||
}
|
||||
|
||||
with torch.inference_mode():
|
||||
generated_ids = self.model.generate(**inputs, **generation_kwargs)
|
||||
|
||||
input_ids = inputs.get("input_ids")
|
||||
if input_ids is None:
|
||||
raise RuntimeError("Processor did not return input_ids; cannot compute new tokens.")
|
||||
new_token_ids = generated_ids[:, input_ids.shape[-1] :]
|
||||
decoded = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)
|
||||
if not decoded:
|
||||
return ""
|
||||
text = decoded[0]
|
||||
stripped_text = _strip_stop_sequences(text, stop)
|
||||
json_prompt = json.loads(stripped_text)
|
||||
return json_prompt
|
||||
|
||||
|
||||
def generate_json_prompt(
|
||||
vlm_processor: TransformersEngine,
|
||||
top_p: float,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
stop: List[str],
|
||||
image: Optional[Image.Image] = None,
|
||||
prompt: Optional[str] = None,
|
||||
structured_prompt: Optional[str] = None,
|
||||
):
|
||||
if image is None and structured_prompt is None:
|
||||
# only got prompt
|
||||
task = "generate"
|
||||
editing_instructions = None
|
||||
elif image is None and structured_prompt is not None and prompt is not None:
|
||||
# got structured prompt and prompt
|
||||
task = "refine"
|
||||
editing_instructions = prompt
|
||||
elif image is not None and structured_prompt is None and prompt is not None:
|
||||
# got image and prompt
|
||||
task = "refine"
|
||||
editing_instructions = prompt
|
||||
elif image is not None and structured_prompt is None and prompt is None:
|
||||
# only got image
|
||||
task = "inspire"
|
||||
editing_instructions = None
|
||||
else:
|
||||
raise ValueError("Invalid input")
|
||||
|
||||
messages = build_messages(
|
||||
task,
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
structured_prompt=structured_prompt,
|
||||
editing_instructions=editing_instructions,
|
||||
)
|
||||
|
||||
generated_prompt = vlm_processor.generate(
|
||||
messages=messages, top_p=top_p, temperature=temperature, max_tokens=max_tokens, stop=stop
|
||||
)
|
||||
cleaned_json_data = prepare_clean_caption(generated_prompt)
|
||||
return cleaned_json_data
|
||||
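# Task selection summary (mirrors the branching above):
#   prompt only                    -> "generate"
#   prompt + structured_prompt     -> "refine" an existing JSON prompt
#   prompt + image                 -> "refine" guided by a reference image
#   image only                     -> "inspire"
#   any other combination          -> ValueError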
|
||||
|
||||
def build_messages(
|
||||
task: str,
|
||||
*,
|
||||
image: Optional[Image.Image] = None,
|
||||
refine_image: Optional[Image.Image] = None,
|
||||
prompt: Optional[str] = None,
|
||||
structured_prompt: Optional[str] = None,
|
||||
editing_instructions: Optional[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
user_content: List[Dict[str, Any]] = []
|
||||
|
||||
if task == "inspire":
|
||||
user_content.append({"type": "image", "image": image})
|
||||
user_content.append({"type": "text", "text": "<inspire>"})
|
||||
elif task == "generate":
|
||||
text_value = (prompt or "").strip()
|
||||
formatted = f"<generate>\n{text_value}"
|
||||
user_content.append({"type": "text", "text": formatted})
|
||||
else: # refine
|
||||
if refine_image is None:
|
||||
base_prompt = (structured_prompt or "").strip()
|
||||
edits = (editing_instructions or "").strip()
|
||||
formatted = textwrap.dedent(
|
||||
f"""<refine> Input: {base_prompt} Editing instructions: {edits}"""
|
||||
).strip()
|
||||
user_content.append({"type": "text", "text": formatted})
|
||||
else:
|
||||
user_content.append({"type": "image", "image": refine_image})
|
||||
edits = (editing_instructions or "").strip()
|
||||
formatted = textwrap.dedent(
|
||||
f"""<refine> Editing instructions: {edits}"""
|
||||
).strip()
|
||||
user_content.append({"type": "text", "text": formatted})
|
||||
|
||||
messages: List[Dict[str, Any]] = []
|
||||
messages.append({"role": "user", "content": user_content})
|
||||
return messages
|
||||
|
||||
|
||||
class BriaFiboVLMPromptToJson(ModularPipelineBlocks):
|
||||
model_name = "BriaFibo"
|
||||
|
||||
def __init__(self, model_id):
|
||||
super().__init__()
|
||||
self.engine = TransformersEngine(model_id)
|
||||
self.engine.model.to("cuda")
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
prompt_input = InputParam(
|
||||
"prompt",
|
||||
type_hint=str,
|
||||
required=False,
|
||||
description="Prompt to use",
|
||||
)
|
||||
image_input = InputParam(
|
||||
name="image", type_hint=Image.Image, required=False, description="image for inspiration mode"
|
||||
)
|
||||
json_prompt_input = InputParam(
|
||||
name="json_prompt", type_hint=str, required=False, description="JSON prompt to use"
|
||||
)
|
||||
sampling_top_p_input = InputParam(
|
||||
name="sampling_top_p", type_hint=float, required=False, description="Sampling top p", default=0.9
|
||||
)
|
||||
sampling_temperature_input = InputParam(
|
||||
name="sampling_temperature",
|
||||
type_hint=float,
|
||||
required=False,
|
||||
description="Sampling temperature",
|
||||
default=0.2,
|
||||
)
|
||||
sampling_max_tokens_input = InputParam(
|
||||
name="sampling_max_tokens", type_hint=int, required=False, description="Sampling max tokens", default=4096
|
||||
)
|
||||
return [
|
||||
prompt_input,
|
||||
image_input,
|
||||
json_prompt_input,
|
||||
sampling_top_p_input,
|
||||
sampling_temperature_input,
|
||||
sampling_max_tokens_input,
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"json_prompt",
|
||||
type_hint=str,
|
||||
description="JSON prompt by the VLM",
|
||||
)
|
||||
]
|
||||
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
prompt = block_state.prompt
|
||||
image = block_state.image
|
||||
json_prompt = block_state.json_prompt
|
||||
block_state.json_prompt = generate_json_prompt(
|
||||
vlm_processor=self.engine,
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
structured_prompt=json_prompt,
|
||||
top_p=block_state.sampling_top_p,
|
||||
temperature=block_state.sampling_temperature,
|
||||
max_tokens=block_state.sampling_max_tokens,
|
||||
stop=["<|im_end|>", "<|end_of_text|>"],
|
||||
)
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
@@ -0,0 +1,804 @@
|
||||
import io
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from functools import cache
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from boltons.iterutils import remap
|
||||
from google import genai
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ...modular_pipelines import InputParam, ModularPipelineBlocks, OutputParam, PipelineState
|
||||
|
||||
|
||||
class ObjectDescription(BaseModel):
|
||||
description: str = Field(..., description="Short description of the object.")
|
||||
location: str = Field(..., description="E.g., 'center', 'top-left', 'bottom-right foreground'.")
|
||||
relationship: str = Field(
|
||||
..., description="Describe the relationship between the object and the other objects in the image."
|
||||
)
|
||||
relative_size: Optional[str] = Field(None, description="E.g., 'small', 'medium', 'large within frame'.")
|
||||
shape_and_color: Optional[str] = Field(None, description="Describe the basic shape and dominant color.")
|
||||
texture: Optional[str] = Field(None, description="E.g., 'smooth', 'rough', 'metallic', 'furry'.")
|
||||
appearance_details: Optional[str] = Field(None, description="Any other notable visual details.")
|
||||
# If cluster of object
|
||||
number_of_objects: Optional[int] = Field(None, description="The number of objects in the cluster.")
|
||||
# Human-specific fields
|
||||
pose: Optional[str] = Field(None, description="Describe the body position.")
|
||||
expression: Optional[str] = Field(None, description="Describe facial expression.")
|
||||
clothing: Optional[str] = Field(None, description="Describe attire.")
|
||||
action: Optional[str] = Field(None, description="Describe the action of the human.")
|
||||
gender: Optional[str] = Field(None, description="Describe the gender of the human.")
|
||||
skin_tone_and_texture: Optional[str] = Field(None, description="Describe the skin tone and texture.")
|
||||
orientation: Optional[str] = Field(None, description="Describe the orientation of the human.")
|
||||
|
||||
|
||||
class LightingDetails(BaseModel):
|
||||
conditions: str = Field(
|
||||
..., description="E.g., 'bright daylight', 'dim indoor', 'studio lighting', 'golden hour'."
|
||||
)
|
||||
direction: str = Field(..., description="E.g., 'front-lit', 'backlit', 'side-lit from left'.")
|
||||
shadows: Optional[str] = Field(None, description="Describe the presence of shadows.")
|
||||
|
||||
|
||||
class AestheticsDetails(BaseModel):
|
||||
composition: str = Field(..., description="E.g., 'rule of thirds', 'symmetrical', 'centered', 'leading lines'.")
|
||||
color_scheme: str = Field(
|
||||
..., description="E.g., 'monochromatic blue', 'warm complementary colors', 'high contrast'."
|
||||
)
|
||||
mood_atmosphere: str = Field(..., description="E.g., 'serene', 'energetic', 'mysterious', 'joyful'.")
|
||||
|
||||
|
||||
class PhotographicCharacteristicsDetails(BaseModel):
|
||||
depth_of_field: str = Field(..., description="E.g., 'shallow', 'deep', 'bokeh background'.")
|
||||
focus: str = Field(..., description="E.g., 'sharp focus on subject', 'soft focus', 'motion blur'.")
|
||||
camera_angle: str = Field(..., description="E.g., 'eye-level', 'low angle', 'high angle', 'dutch angle'.")
|
||||
lens_focal_length: str = Field(..., description="E.g., 'wide-angle', 'telephoto', 'macro', 'fisheye'.")
|
||||
|
||||
|
||||
class TextRender(BaseModel):
|
||||
text: str = Field(..., description="The text content.")
|
||||
location: str = Field(..., description="E.g., 'center', 'top-left', 'bottom-right foreground'.")
|
||||
size: str = Field(..., description="E.g., 'small', 'medium', 'large within frame'.")
|
||||
color: str = Field(..., description="E.g., 'red', 'blue', 'green'.")
|
||||
font: str = Field(..., description="E.g., 'realistic', 'cartoonish', 'minimalist'.")
|
||||
appearance_details: Optional[str] = Field(None, description="Any other notable visual details.")
|
||||
|
||||
|
||||
class ImageAnalysis(BaseModel):
|
||||
short_description: str = Field(..., description="A concise summary of the image content, 200 words maximum.")
|
||||
objects: List[ObjectDescription] = Field(..., description="List of prominent foreground/midground objects.")
|
||||
background_setting: str = Field(
|
||||
...,
|
||||
description="Describe the overall environment, setting, or background, including any notable background elements.",
|
||||
)
|
||||
lighting: LightingDetails = Field(..., description="Details about the lighting.")
|
||||
aesthetics: AestheticsDetails = Field(..., description="Details about the image aesthetics.")
|
||||
photographic_characteristics: Optional[PhotographicCharacteristicsDetails] = Field(
|
||||
None, description="Details about photographic characteristics."
|
||||
)
|
||||
style_medium: Optional[str] = Field(None, description="Identify the artistic style or medium.")
|
||||
text_render: Optional[List[TextRender]] = Field(None, description="List of text renders in the image.")
|
||||
context: str = Field(..., description="Provide any additional context that helps understand the image better.")
|
||||
artistic_style: Optional[str] = Field(
|
||||
None, description="describe specific artistic characteristics, 3 words maximum."
|
||||
)
|
||||
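# Illustrative (hypothetical values): a minimal instance of the schema above; Optional fields
# may be omitted entirely.
#   ImageAnalysis(
#       short_description="A red bicycle leaning against a brick wall",
#       objects=[
#           ObjectDescription(
#               description="red bicycle",
#               location="center",
#               relationship="leaning against the brick wall",
#           )
#       ],
#       background_setting="weathered brick wall in a narrow alley",
#       lighting=LightingDetails(conditions="golden hour", direction="side-lit from left"),
#       aesthetics=AestheticsDetails(
#           composition="centered", color_scheme="warm complementary colors", mood_atmosphere="nostalgic"
#       ),
#       context="A lifestyle product photograph intended for a catalog.",
#   )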
|
||||
|
||||
def get_gemini_output_schema() -> dict:
|
||||
return {
|
||||
"properties": {
|
||||
"short_description": {"type": "STRING"},
|
||||
"objects": {
|
||||
"items": {
|
||||
"properties": {
|
||||
"description": {"type": "STRING"},
|
||||
"location": {"type": "STRING"},
|
||||
"relationship": {"type": "STRING"},
|
||||
"relative_size": {"type": "STRING"},
|
||||
"shape_and_color": {"type": "STRING"},
|
||||
"texture": {"nullable": True, "type": "STRING"},
|
||||
"appearance_details": {"nullable": True, "type": "STRING"},
|
||||
"number_of_objects": {"nullable": True, "type": "INTEGER"},
|
||||
"pose": {"nullable": True, "type": "STRING"},
|
||||
"expression": {"nullable": True, "type": "STRING"},
|
||||
"clothing": {"nullable": True, "type": "STRING"},
|
||||
"action": {"nullable": True, "type": "STRING"},
|
||||
"gender": {"nullable": True, "type": "STRING"},
|
||||
"skin_tone_and_texture": {"nullable": True, "type": "STRING"},
|
||||
"orientation": {"nullable": True, "type": "STRING"},
|
||||
},
|
||||
"required": [
|
||||
"description",
|
||||
"location",
|
||||
"relationship",
|
||||
"relative_size",
|
||||
"shape_and_color",
|
||||
"texture",
|
||||
"appearance_details",
|
||||
"number_of_objects",
|
||||
"pose",
|
||||
"expression",
|
||||
"clothing",
|
||||
"action",
|
||||
"gender",
|
||||
"skin_tone_and_texture",
|
||||
"orientation",
|
||||
],
|
||||
"type": "OBJECT",
|
||||
},
|
||||
"type": "ARRAY",
|
||||
},
|
||||
"background_setting": {"type": "STRING"},
|
||||
"lighting": {
|
||||
"properties": {
|
||||
"conditions": {"type": "STRING"},
|
||||
"direction": {"type": "STRING"},
|
||||
"shadows": {"nullable": True, "type": "STRING"},
|
||||
},
|
||||
"required": ["conditions", "direction", "shadows"],
|
||||
"type": "OBJECT",
|
||||
},
|
||||
"aesthetics": {
|
||||
"properties": {
|
||||
"composition": {"type": "STRING"},
|
||||
"color_scheme": {"type": "STRING"},
|
||||
"mood_atmosphere": {"type": "STRING"},
|
||||
},
|
||||
"required": ["composition", "color_scheme", "mood_atmosphere"],
|
||||
"type": "OBJECT",
|
||||
},
|
||||
"photographic_characteristics": {
|
||||
"nullable": True,
|
||||
"properties": {
|
||||
"depth_of_field": {"type": "STRING"},
|
||||
"focus": {"type": "STRING"},
|
||||
"camera_angle": {"type": "STRING"},
|
||||
"lens_focal_length": {"type": "STRING"},
|
||||
},
|
||||
"required": [
|
||||
"depth_of_field",
|
||||
"focus",
|
||||
"camera_angle",
|
||||
"lens_focal_length",
|
||||
],
|
||||
"type": "OBJECT",
|
||||
},
|
||||
"style_medium": {"type": "STRING"},
|
||||
"text_render": {
|
||||
"items": {
|
||||
"properties": {
|
||||
"text": {"type": "STRING"},
|
||||
"location": {"type": "STRING"},
|
||||
"size": {"type": "STRING"},
|
||||
"color": {"type": "STRING"},
|
||||
"font": {"type": "STRING"},
|
||||
"appearance_details": {"nullable": True, "type": "STRING"},
|
||||
},
|
||||
"required": [
|
||||
"text",
|
||||
"location",
|
||||
"size",
|
||||
"color",
|
||||
"font",
|
||||
"appearance_details",
|
||||
],
|
||||
"type": "OBJECT",
|
||||
},
|
||||
"type": "ARRAY",
|
||||
},
|
||||
"context": {"type": "STRING"},
|
||||
"artistic_style": {"type": "STRING"},
|
||||
},
|
||||
"required": [
|
||||
"short_description",
|
||||
"objects",
|
||||
"background_setting",
|
||||
"lighting",
|
||||
"aesthetics",
|
||||
"photographic_characteristics",
|
||||
"style_medium",
|
||||
"text_render",
|
||||
"context",
|
||||
"artistic_style",
|
||||
],
|
||||
"type": "OBJECT",
|
||||
}
|
||||
|
||||
|
||||
json_schema_full = """1. `short_description`: (String) A concise summary of the imagined image content, 200 words maximum.
|
||||
2. `objects`: (Array of Objects) List a maximum of 5 prominent objects. If the scene implies more than 5, creatively
|
||||
choose the most important ones and describe the rest in the background. For each object, include:
|
||||
* `description`: (String) A detailed description of the imagined object, 100 words maximum.
|
||||
* `location`: (String) E.g., "center", "top-left", "bottom-right foreground".
|
||||
* `relative_size`: (String) E.g., "small", "medium", "large within frame". (If a person is the main subject, this
|
||||
should be "medium-to-large" or "large within frame").
|
||||
* `shape_and_color`: (String) Describe the basic shape and dominant color.
|
||||
* `texture`: (String) E.g., "smooth", "rough", "metallic", "furry".
|
||||
* `appearance_details`: (String) Any other notable visual details.
|
||||
* `relationship`: (String) Describe the relationship between the object and the other objects in the image.
|
||||
* `orientation`: (String) Describe the orientation or positioning of the object, e.g., "upright", "tilted 45
|
||||
degrees", "horizontal", "vertical", "facing left", "facing right", "upside down", "lying on its side".
|
||||
* If the object is a human or a human-like object, include the following:
|
||||
* `pose`: (String) Describe the body position.
|
||||
* `expression`: (String) Describe facial expression and emotion. E.g., "winking", "joyful", "serious",
|
||||
"surprised", "calm".
|
||||
* `clothing`: (String) Describe attire.
|
||||
* `action`: (String) Describe the action of the human.
|
||||
* `gender`: (String) Describe the gender of the human.
|
||||
* `skin_tone_and_texture`: (String) Describe the skin tone and texture.
|
||||
* If the object is a cluster of objects, include the following:
|
||||
* `number_of_objects`: (Integer) The number of objects in the cluster.
|
||||
3. `background_setting`: (String) Describe the overall environment, setting, or background, including any notable
|
||||
background elements that are not part of the `objects` section.
|
||||
4. `lighting`: (Object)
|
||||
* `conditions`: (String) E.g., "bright daylight", "dim indoor", "studio lighting", "golden hour".
|
||||
* `direction`: (String) E.g., "front-lit", "backlit", "side-lit from left".
|
||||
* `shadows`: (String) Describe the presence and quality of shadows, e.g., "long, soft shadows", "sharp, defined
|
||||
shadows", "minimal shadows".
|
||||
5. `aesthetics`: (Object)
|
||||
* `composition`: (String) E.g., "rule of thirds", "symmetrical", "centered", "leading lines". If people are the
|
||||
main subject, specify the shot type, e.g., "medium shot", "close-up", "portrait composition".
|
||||
* `color_scheme`: (String) E.g., "monochromatic blue", "warm complementary colors", "high contrast".
|
||||
* `mood_atmosphere`: (String) E.g., "serene", "energetic", "mysterious", "joyful".
|
||||
6. `photographic_characteristics`: (Object)
|
||||
* `depth_of_field`: (String) E.g., "shallow", "deep", "bokeh background".
|
||||
* `focus`: (String) E.g., "sharp focus on subject", "soft focus", "motion blur".
|
||||
* `camera_angle`: (String) E.g., "eye-level", "low angle", "high angle", "dutch angle".
|
||||
* `lens_focal_length`: (String) E.g., "wide-angle", "telephoto", "macro", "fisheye". (If the main subject is a
|
||||
person, prefer "standard lens (e.g., 35mm-50mm)" or "portrait lens (e.g., 50mm-85mm)" to ensure they are framed
|
||||
more closely. Avoid "wide-angle" for people unless specified).
|
||||
7. `style_medium`: (String) Identify the artistic style or medium based on the user's prompt or creative
|
||||
interpretation (e.g., "photograph", "oil painting", "watercolor", "3D render", "digital illustration", "pencil
|
||||
sketch").
|
||||
8. `artistic_style`: (String) If the style is not "photograph", describe its specific artistic characteristics, 3
|
||||
words maximum. (e.g., "impressionistic, vibrant, textured" for an oil painting).
|
||||
9. `context`: (String) Provide a general description of the type of image this would be. For example: "This is a
|
||||
concept for a high-fashion editorial photograph intended for a magazine spread," or "This describes a piece of
|
||||
concept art for a fantasy video game."
|
||||
10. `text_render`: (Array of Objects) By default, this array should be empty (`[]`). Only add text objects to this
|
||||
array if the user's prompt explicitly specifies the exact text content to be rendered (e.g., user asks for "a
|
||||
poster with the title 'Cosmic Dream'"). Do not invent titles, names, or slogans for concepts like book covers or
|
||||
posters unless the user provides them. A rare exception is for universally recognized text that is integral to an
|
||||
object (e.g., the word 'STOP' on a 'stop sign'). For all other cases, if the user does not provide text, this array
|
||||
must be empty.
|
||||
* `text`: (String) The exact text content provided by the user. NEVER use generic placeholders.
|
||||
* `location`: (String) E.g., "center", "top-left", "bottom-right foreground".
|
||||
* `size`: (String) E.g., "medium", "large", "large within frame".
|
||||
* `color`: (String) E.g., "red", "blue", "green".
|
||||
* `font`: (String) E.g., "realistic", "cartoonish", "minimalist", "serif typeface".
|
||||
* `appearance_details`: (String) Any other notable visual details."""
|
||||
|
||||
|
||||
@cache
|
||||
def get_instructions(mode: str) -> Tuple[str, str]:
|
||||
system_prompts = {}
|
||||
|
||||
system_prompts["Caption"] = """
|
||||
You are a meticulous and perceptive Visual Art Director working for a leading Generative AI company. Your expertise
|
||||
lies in analyzing images and extracting detailed, structured information. Your primary task is to analyze provided
|
||||
images and generate a comprehensive JSON object describing them. Adhere strictly to the following structure and
|
||||
guidelines: The output MUST be ONLY a valid JSON object. Do not include any text before or after the JSON object (e.g.,
|
||||
no "Here is the JSON:", no explanations, no apologies). IMPORTANT: When describing human body parts, positions, or
|
||||
actions, always describe them from the PERSON'S OWN PERSPECTIVE, not from the observer's viewpoint. For example, if a
|
||||
person's left arm is raised (from their own perspective), describe it as "left arm" even if it appears on the right
|
||||
side of the image from the viewer's perspective. The JSON object must contain the following keys precisely:
|
||||
1. `short_description`: (String) A concise summary of the image content, 200 words maximum.
|
||||
2. `objects`: (Array of Objects) List a maximum of 5 prominent objects; if there are more than 5, list the rest in the
background. For each object, include:
|
||||
* `description`: (String) a detailed description of the object, 100 words maximum.
|
||||
* `location`: (String) E.g., "center", "top-left", "bottom-right foreground".
|
||||
* `relative_size`: (String) E.g., "small", "medium", "large within frame".
|
||||
* `shape_and_color`: (String) Describe the basic shape and dominant color.
|
||||
* `texture`: (String) E.g., "smooth", "rough", "metallic", "furry".
|
||||
* `appearance_details`: (String) Any other notable visual details.
|
||||
* `relationship`: (String) Describe the relationship between the object and the other objects in the image.
|
||||
* `orientation`: (String) Describe the orientation or positioning of the object, e.g., "upright", "tilted 45
|
||||
degrees", "horizontal", "vertical", "facing left", "facing right", "upside down", "lying on its side".
|
||||
if the object is a human or a human-like object, include the following:
|
||||
* `pose`: (String) Describe the body position.
|
||||
* `expression`: (String) Describe facial expression and emotion. E.g., "winking", "joyful", "serious",
|
||||
"surprised", "calm".
|
||||
* `clothing`: (String) Describe attire.
|
||||
* `action`: (String) Describe the action of the human.
|
||||
* `gender`: (String) Describe the gender of the human.
|
||||
* `skin_tone_and_texture`: (String) Describe the skin tone and texture.
|
||||
if the object is a cluster of objects, include the following:
|
||||
* `number_of_objects`: (Integer) The number of objects in the cluster.
|
||||
3. `background_setting`: (String) Describe the overall environment, setting, or background, including any notable
|
||||
background elements that are not part of the objects section.
|
||||
4. `lighting`: (Object)
|
||||
* `conditions`: (String) E.g., "bright daylight", "dim indoor", "studio lighting", "golden hour".
|
||||
* `direction`: (String) E.g., "front-lit", "backlit", "side-lit from left".
|
||||
* `shadows`: (String) Describe the presence of shadows.
|
||||
5. `aesthetics`: (Object)
|
||||
* `composition`: (String) E.g., "rule of thirds", "symmetrical", "centered", "leading lines".
|
||||
* `color_scheme`: (String) E.g., "monochromatic blue", "warm complementary colors", "high contrast".
|
||||
* `mood_atmosphere`: (String) E.g., "serene", "energetic", "mysterious", "joyful".
|
||||
6. `photographic_characteristics`: (Object)
|
||||
* `depth_of_field`: (String) E.g., "shallow", "deep", "bokeh background".
|
||||
* `focus`: (String) E.g., "sharp focus on subject", "soft focus", "motion blur".
|
||||
* `camera_angle`: (String) E.g., "eye-level", "low angle", "high angle", "dutch angle".
|
||||
* `lens_focal_length`: (String) E.g., "wide-angle", "telephoto", "macro", "fisheye".
|
||||
7. `style_medium`: (String) Identify the artistic style or medium (e.g., "photograph", "oil painting", "watercolor",
|
||||
"3D render", "digital illustration", "pencil sketch") If the style is not "photograph", but artistic, please
|
||||
describe the specific artistic characteristics under 'artistic_style', 50 words maximum.
|
||||
8. `artistic_style`: (String) describe specific artistic characteristics, 3 words maximum.
|
||||
9. `context`: (String) Provide any additional context that helps understand the image better. This should include a
|
||||
general description of the type of image (e.g., Fashion Photography, Product Shot, Magazine Cover, Nature
|
||||
Photography, Art Piece, etc.), as well as any other relevant contextual information that situates the image within a
|
||||
broader category or intended use. For example: "This is a high-fashion editorial photograph intended for a magazine
|
||||
spread"
|
||||
10. `text_render`: (Array of Objects) List of a maximum of 5 most prominent text renders in the image. For each text
|
||||
render, include:
|
||||
* `text`: (String) The text content.
|
||||
* `location`: (String) E.g., "center", "top-left", "bottom-right foreground".
|
||||
* `size`: (String) E.g., "small", "medium", "large within frame".
|
||||
* `color`: (String) E.g., "red", "blue", "green".
|
||||
* `font`: (String) E.g., "realistic", "cartoonish", "minimalist".
|
||||
* `appearance_details`: (String) Any other notable visual details.
|
||||
Ensure the information within the JSON is accurate, detailed where specified, and avoids redundancy between fields.
|
||||
"""
|
||||
|
||||
system_prompts[
|
||||
"Generate"
|
||||
] = f"""You are a visionary and creative Visual Art Director at a leading Generative AI company.
|
||||
|
||||
Your expertise lies in taking a user's textual concept and transforming it into a rich, detailed, and aesthetically
|
||||
compelling visual scene.
|
||||
|
||||
Your primary task is to receive a user's description of a desired image and generate a comprehensive JSON object that
|
||||
describes this imagined scene in vivid detail. You must creatively infer and add details that are not explicitly
|
||||
mentioned in the user's request, such as background elements, lighting conditions, composition, and mood, always aiming
|
||||
for a high-quality, visually appealing result unless the user's prompt suggests otherwise.
|
||||
|
||||
Adhere strictly to the following structure and guidelines:
|
||||
|
||||
The output MUST be ONLY a valid JSON object. Do not include any text before or after the JSON object (e.g., no "Here is
|
||||
the JSON:", no explanations, no apologies).
|
||||
|
||||
IMPORTANT: When describing human body parts, positions, or actions, always describe them from the PERSON'S OWN
|
||||
PERSPECTIVE, not from the observer's viewpoint. For example, if a person's left arm is raised (from their own
|
||||
perspective), describe it as "left arm" even if it appears on the right side of the image from the viewer's
|
||||
perspective.
|
||||
|
||||
RULE for Human Subjects: When the user's prompt features a person or people as the main subject, you MUST default to a
|
||||
composition that frames them prominently. Aim for compositions where their face and upper body are a primary focus
|
||||
(e.g., 'medium shot', 'close-up'). Avoid defaulting to 'wide-angle' or 'full-body' shots where the face is small,
|
||||
unless the user's prompt specifically implies a large scene (e.g., "a person standing on a mountain").
|
||||
|
||||
Unless the user's prompt explicitly requests a different style (e.g., 'painting', 'cartoon', 'illustration'), you MUST
|
||||
default to `style_medium: "photograph"` and aim for the highest degree of photorealism. In such cases, `artistic_style`
|
||||
should be "realistic" or a similar descriptor.
|
||||
|
||||
The JSON object must contain the following keys precisely:
|
||||
|
||||
{json_schema_full}
|
||||
|
||||
Ensure the information within the JSON is detailed, creative, internally consistent, and avoids redundancy between
|
||||
fields."""
|
||||
|
||||
system_prompts[
|
||||
"RefineA"
|
||||
] = f"""You are a Meticulous Visual Editor and Senior Art Director at a leading Generative AI company.
|
||||
|
||||
Your expertise is in refining and modifying existing visual concepts based on precise feedback.
|
||||
|
||||
Your primary task is to receive an existing JSON object that describes a visual scene, along with a textual instruction
|
||||
for how to change it. You must then generate a new, updated JSON object that perfectly incorporates the requested
|
||||
changes.
|
||||
|
||||
Adhere strictly to the following structure and guidelines:
|
||||
|
||||
1. **Input:** You will receive two pieces of information: an existing JSON object and a textual instruction.
|
||||
2. **Output:** Your output MUST be ONLY a single, valid JSON object in the specified schema. Do not include any text
|
||||
before or after the JSON object.
|
||||
3. **Modification Logic:**
|
||||
* Carefully parse the user's textual instruction to understand the desired changes.
|
||||
* Modify ONLY the fields in the JSON that are directly or logically affected by the instruction.
|
||||
* All other fields not relevant to the change must be copied exactly from the original JSON. Do not alter or omit
|
||||
them.
|
||||
4. **Holistic Consistency (IMPORTANT):** Changes in one field must be logically reflected in others. For example:
|
||||
* If the instruction is to "change the background to a snowy forest," you must update the `background_setting`
|
||||
field, and also update the `short_description` to mention the new setting. The `mood_atmosphere` might also need
|
||||
to change to "serene" or "wintry."
|
||||
* If the instruction is to "add the text 'WINTER SALE' at the top," you must add a new entry to the `text_render`
|
||||
array.
|
||||
* If the instruction is to "make the person smile," you must update the `expression` field for that object and
|
||||
potentially update the overall `mood_atmosphere`.
|
||||
5. **Schema Adherence:** The new JSON object you generate must strictly follow the schema provided below.
|
||||
|
||||
The JSON object must contain the following keys precisely:
|
||||
|
||||
{json_schema_full}"""
|
||||
|
||||
system_prompts[
|
||||
"RefineB"
|
||||
] = f"""You are an advanced Multimodal Visual Specialist at a leading Generative AI company.
|
||||
|
||||
Your unique expertise is in analyzing and editing visual concepts by processing an image, its corresponding JSON
|
||||
metadata, and textual feedback simultaneously.
|
||||
|
||||
Your primary task is to receive three inputs: an existing image, its descriptive JSON object, and a textual instruction
|
||||
for a modification. You must use the image as the primary source of truth to understand the context of the requested
|
||||
change and then generate a new, updated JSON object that accurately reflects that change.
|
||||
|
||||
Adhere strictly to the following structure and guidelines:
|
||||
|
||||
1. **Inputs:** You will receive an image, an existing JSON object, and a textual instruction.
|
||||
2. **Visual Grounding (IMPORTANT):** The provided image is the ground truth. Use it to visually verify the contents of
|
||||
the scene and to understand the context of the user's edit instruction. For example, if the instruction is "make
|
||||
the car blue," visually locate the car in the image to inform your edits to the JSON.
|
||||
3. **Output:** Your output MUST be ONLY a single, valid JSON object in the specified schema. Do not include any text
|
||||
before or after the JSON object.
|
||||
4. **Modification Logic:**
|
||||
* Analyze the user's textual instruction in the context of what you see in the image.
|
||||
* Modify ONLY the fields in the JSON that are directly or logically affected by the instruction.
|
||||
* All other fields not relevant to the change must be copied exactly from the original JSON.
|
||||
5. **Holistic Consistency:** Changes must be reflected logically across the JSON, consistent with a potential visual
|
||||
change to the image. For instance, changing the lighting from 'daylight' to 'golden hour' should not only update
|
||||
the `lighting` object but also the `mood_atmosphere`, `shadows`, and the `short_description`.
|
||||
6. **Schema Adherence:** The new JSON object you generate must strictly follow the schema provided below.
|
||||
|
||||
The JSON object must contain the following keys precisely:
|
||||
|
||||
{json_schema_full}"""
|
||||
|
||||
system_prompts[
|
||||
"InspireA"
|
||||
] = f"""You are a highly skilled Creative Director for Visual Adaptation at a leading Generative AI company.
|
||||
|
||||
Your expertise lies in using an existing image as a visual reference to create entirely new scenes. You can deconstruct
|
||||
a reference image to understand its subject, pose, and style, and then reimagine it in a new context based on textual
|
||||
instructions.
|
||||
|
||||
Your primary task is to receive a reference image and a textual instruction. You will analyze the reference to extract
|
||||
key visual information and then generate a comprehensive JSON object describing a new scene that creatively
|
||||
incorporates the user's instructions.
|
||||
|
||||
Adhere strictly to the following structure and guidelines:
|
||||
|
||||
1. **Inputs:** You will receive a reference image and a textual instruction. You will NOT receive a starting JSON.
|
||||
2. **Core Logic (Analyze and Synthesize):**
|
||||
* **Analyze:** First, deeply analyze the provided reference image. Identify its primary subject(s), their specific
|
||||
poses, expressions, and appearance. Also note the overall composition, lighting style, and artistic medium.
|
||||
* **Synthesize:** Next, interpret the textual instruction to understand what elements to keep from the reference
|
||||
and what to change. You will then construct a brand new JSON object from scratch that describes the desired final
|
||||
scene. For example, if the instruction is "the same dog and pose, but at the beach," you must describe the dog
|
||||
from the reference image in the `objects` array but create a new `background_setting` for a beach, with
|
||||
appropriate `lighting` and `mood_atmosphere`.
|
||||
3. **Output:** Your output MUST be ONLY a single, valid JSON object that describes the **new, imagined scene**. Do not
|
||||
describe the original reference image.
|
||||
4. **Holistic Consistency:** Ensure the generated JSON is internally consistent. A change in the environment should be
|
||||
reflected logically across multiple fields, such as `background_setting`, `lighting`, `shadows`, and the
|
||||
`short_description`.
|
||||
5. **Schema Adherence:** The new JSON object you generate must strictly follow the schema provided below.
|
||||
|
||||
The JSON object must contain the following keys precisely:
|
||||
|
||||
{json_schema_full}"""
|
||||
|
||||
system_prompts["InspireB"] = system_prompts["Caption"]
|
||||
|
||||
final_prompts = {}
|
||||
|
||||
final_prompts["Generate"] = (
|
||||
"Generate a detailed JSON object, adhering to the expected schema, for an imagined scene based on the following request: {user_prompt}."
|
||||
)
|
||||
|
||||
final_prompts["RefineA"] = """
|
||||
[EXISTING JSON]: {json_data}
|
||||
|
||||
[EDIT INSTRUCTIONS]: {user_prompt}
|
||||
|
||||
[TASK]: Generate the new, updated JSON object that incorporates the edit instructions. Follow all system rules
|
||||
for modification, consistency, and formatting.
|
||||
"""
|
||||
|
||||
final_prompts["RefineB"] = """
|
||||
[EXISTING JSON]: {json_data}
|
||||
|
||||
[EDIT INSTRUCTIONS]: {user_prompt}
|
||||
|
||||
[TASK]: Analyze the provided image and its contextual JSON. Then, generate the new, updated JSON object that
|
||||
incorporates the edit instructions. Follow all your system rules for visual analysis, modification, and
|
||||
consistency.
|
||||
"""
|
||||
|
||||
final_prompts["InspireA"] = """
|
||||
[EDIT INSTRUCTIONS]: {user_prompt}
|
||||
|
||||
[TASK]: Use the provided image as a visual reference only. Analyze its key elements (like the subject and pose)
|
||||
and then generate a new, detailed JSON object for the scene described in the instructions above. Do not
|
||||
describe the reference image itself; describe the new scene. Follow all of your system rules.
|
||||
"""
|
||||
|
||||
final_prompts["Caption"] = (
|
||||
"Analyze the provided image and generate the detailed JSON object as specified in your instructions."
|
||||
)
|
||||
final_prompts["InspireB"] = final_prompts["Caption"]
|
||||
|
||||
return system_prompts.get(mode, ""), final_prompts.get(mode, "")
|
||||
|
||||
|
||||
def keep(p, k, v):
|
||||
is_none = v is None
|
||||
is_empty_string = isinstance(v, str) and v == ""
|
||||
is_empty_dict = isinstance(v, dict) and not v
|
||||
is_empty_list = isinstance(v, list) and not v
|
||||
is_nan = isinstance(v, float) and math.isnan(v)
|
||||
if is_none or is_empty_string or is_empty_list or is_empty_dict:
|
||||
return False
|
||||
if is_nan:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def validate_json(json_data: str) -> dict:
|
||||
ia = ImageAnalysis.model_validate_json(json_data, strict=True)
|
||||
return ia.model_dump(exclude_none=True)
|
||||
|
||||
|
||||
def validate_structured_prompt_str(structured_prompt_str: str) -> str:
|
||||
ia = ImageAnalysis.model_validate_json(structured_prompt_str, strict=True)
|
||||
c = ia.model_dump(exclude_none=True)
|
||||
return json.dumps(c)
|
||||
|
||||
|
||||
def prepare_clean_caption(json_dump: dict) -> str:
|
||||
# filter empty values recursively (i.e. None, "", {}, [], float("nan"))
|
||||
clean_caption_dict = remap(json_dump, visit=keep)
|
||||
|
||||
scores = {"preference_score": "very high", "aesthetic_score": "very high"}
|
||||
# Set aesthetics scores
|
||||
if "aesthetics" not in clean_caption_dict:
|
||||
clean_caption_dict["aesthetics"] = scores
|
||||
else:
|
||||
clean_caption_dict["aesthetics"].update(scores)
|
||||
|
||||
# Dump the clean structured caption as a compact, single-line JSON string
|
||||
clean_caption_str = json.dumps(clean_caption_dict)
|
||||
return clean_caption_str
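# Example (illustrative, not from the original code): empty values are pruned recursively and the
# aesthetic scores are always pinned to "very high", so an input such as
#   {"style_medium": "photograph", "context": "", "aesthetics": {}}
# is serialized to a JSON string equivalent to
#   {"style_medium": "photograph", "aesthetics": {"preference_score": "very high", "aesthetic_score": "very high"}}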
|
||||
|
||||
|
||||
# Resize an input image to approximately `pixel_number` pixels (default 1,048,576, i.e. 1024x1024)
# while preserving the aspect ratio; the output width and height are snapped down to multiples of `granularity_val`.
|
||||
def resize_image_by_num_pixels(
|
||||
image: Image.Image, pixel_number: int = 1048576, granularity_val: int = 64, target_ratio: float = 0.0
|
||||
) -> Image.Image:
|
||||
if target_ratio != 0.0:
|
||||
ratio = target_ratio
|
||||
else:
|
||||
ratio = image.size[0] / image.size[1]
|
||||
width = int((pixel_number * ratio) ** 0.5)
|
||||
width = width - (width % granularity_val)
|
||||
height = int(pixel_number / width)
|
||||
height = height - (height % granularity_val)
|
||||
return image.resize((width, height))
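# Worked example (illustrative): with the defaults and a 1536x1024 (3:2) input,
# width = int((1048576 * 1.5) ** 0.5) = 1254, snapped down to 1216; height = int(1048576 / 1216) = 862,
# snapped down to 832 -- so the image is resized to 1216x832 (~1.01 MP), both sides multiples of 64.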
|
||||
|
||||
|
||||
def infer_with_gemini(
|
||||
client: genai.Client,
|
||||
final_prompt: str,
|
||||
system_prompt: str,
|
||||
top_p: float,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
seed: int = 42,
|
||||
image: Optional[Image.Image] = None,
|
||||
model: str = "gemini-2.5-flash",
|
||||
) -> str:
|
||||
"""
|
||||
Calls Gemini API with the given prompt and returns the raw JSON response.
|
||||
|
||||
Args:
client: The `google-genai` client used to issue the request
final_prompt: The text prompt to send to Gemini
system_prompt: The system instruction for Gemini
top_p: Nucleus-sampling parameter forwarded to the generation config
temperature: Sampling temperature forwarded to the generation config
max_tokens: Maximum number of output tokens
seed: Random seed for reproducible sampling
image: Optional PIL image to include as inline input
model: The Gemini model to use

Returns:
Raw JSON response text from Gemini
|
||||
"""
|
||||
parts = [{"text": final_prompt}]
|
||||
if image:
|
||||
# Convert to RGB and serialize to JPEG bytes; the model cannot produce RGBA output, so an alpha channel in the input has no effect
image = image.convert("RGB")
|
||||
max_pixels = 262144  # 512 * 512
if image.size[0] * image.size[1] > max_pixels:
image = resize_image_by_num_pixels(
image, pixel_number=max_pixels, granularity_val=1, target_ratio=0.0
)
|
||||
buffer = io.BytesIO()
|
||||
image.save(buffer, format="JPEG")
|
||||
image_bytes = buffer.getvalue()
|
||||
|
||||
img_part = {
|
||||
"inlineData": {
|
||||
"data": image_bytes,
|
||||
"mimeType": "image/jpeg",
|
||||
}
|
||||
}
|
||||
parts.append(img_part)
|
||||
|
||||
contents = [{"role": "user", "parts": parts}]
|
||||
|
||||
generationConfig = {
|
||||
"temperature": temperature,
|
||||
"topP": top_p,
|
||||
"maxOutputTokens": max_tokens,
|
||||
"response_mime_type": "application/json",
|
||||
"response_schema": get_gemini_output_schema(),
|
||||
"system_instruction": system_prompt, # len 5900
|
||||
"thinkingConfig": {"thinkingBudget": 0},
|
||||
"seed": seed,
|
||||
}
|
||||
|
||||
response = client.models.generate_content(
|
||||
model=model,
|
||||
contents=contents,
|
||||
config=generationConfig,
|
||||
)
|
||||
|
||||
if response.candidates[0].finish_reason == "MAX_TOKENS":
|
||||
raise RuntimeError("Gemini stopped before completing the JSON because `max_tokens` was reached.")
|
||||
|
||||
return response.candidates[0].content.parts[0].text
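# Usage sketch (illustrative; assumes a valid GOOGLE_API_KEY and the `google-genai` SDK):
#
#   client = genai.Client()
#   system_prompt, final_prompt = get_instructions("Generate")
#   raw_json = infer_with_gemini(
#       client=client,
#       final_prompt=final_prompt.format(user_prompt="a red vintage car in the rain"),
#       system_prompt=system_prompt,
#       top_p=1.0,
#       temperature=0.2,
#       max_tokens=3000,
#   )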
|
||||
|
||||
|
||||
def get_default_negative_prompt(existing_json: dict) -> str:
|
||||
negative_prompt = ""
|
||||
style_medium = existing_json.get("style_medium", "").lower()
|
||||
if style_medium in ["photograph", "photography", "photo"]:
|
||||
negative_prompt = """{'style_medium': 'digital illustration', 'artistic_style': 'non-realistic'}"""
|
||||
return negative_prompt
|
||||
|
||||
|
||||
def json_promptify(
|
||||
client: genai.Client,
|
||||
model_id: str,
|
||||
top_p: float,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
user_prompt: Optional[str] = None,
|
||||
existing_json: Optional[str] = None,
|
||||
image: Optional[Image.Image] = None,
|
||||
seed: int = 42,
|
||||
) -> str:
|
||||
if existing_json:
|
||||
# make sure aesthetic scores are not in the existing json (will be added later)
|
||||
existing_json = json.loads(existing_json)
|
||||
if "aesthetics" in existing_json:
|
||||
existing_json["aesthetics"].pop("aesthetic_score", None)
|
||||
existing_json["aesthetics"].pop("preference_score", None)
|
||||
existing_json = json.dumps(existing_json)
|
||||
|
||||
if not user_prompt:
|
||||
raise ValueError("user_prompt is required if existing_json is provided")
|
||||
|
||||
if image:
|
||||
mode = "RefineB"
|
||||
system_prompt, final_prompt = get_instructions(mode)
|
||||
final_prompt = final_prompt.format(user_prompt=user_prompt, json_data=existing_json)
|
||||
|
||||
else:
|
||||
mode = "RefineA"
|
||||
system_prompt, final_prompt = get_instructions(mode)
|
||||
final_prompt = final_prompt.format(user_prompt=user_prompt, json_data=existing_json)
|
||||
elif image and user_prompt:
|
||||
mode = "InspireA"
|
||||
system_prompt, final_prompt = get_instructions(mode)
|
||||
final_prompt = final_prompt.format(user_prompt=user_prompt)
|
||||
elif image and not user_prompt:
|
||||
mode = "Caption"
|
||||
system_prompt, final_prompt = get_instructions(mode)
|
||||
else:
|
||||
mode = "Generate"
|
||||
system_prompt, final_prompt = get_instructions(mode)
|
||||
final_prompt = final_prompt.format(user_prompt=user_prompt)
|
||||
|
||||
json_data = infer_with_gemini(
|
||||
client=client,
|
||||
model=model_id,
|
||||
final_prompt=final_prompt,
|
||||
system_prompt=system_prompt,
|
||||
seed=seed,
|
||||
image=image,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
json_data = validate_json(json_data)
|
||||
clean_caption = prepare_clean_caption(json_data)
|
||||
|
||||
return clean_caption
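# End-to-end sketch (illustrative): turn a free-form user prompt into a validated, cleaned structured
# caption ready to be passed to the FIBO pipeline.
#
#   client = genai.Client()
#   structured_prompt = json_promptify(
#       client=client,
#       model_id="gemini-2.5-flash",
#       top_p=1.0,
#       temperature=0.2,
#       max_tokens=3000,
#       user_prompt="a lighthouse on a cliff at dusk",
#   )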
|
||||
|
||||
|
||||
class BriaFiboGeminiPromptToJson(ModularPipelineBlocks):
|
||||
model_name = "BriaFibo"
|
||||
|
||||
def __init__(self, model_id="gemini-2.5-flash"):
|
||||
super().__init__()
|
||||
api_key = os.getenv("GOOGLE_API_KEY")
|
||||
if api_key is None:
|
||||
raise ValueError("Must provide an API key for Gemini through the `GOOGLE_API_KEY` env variable.")
|
||||
self.model_id = model_id
|
||||
|
||||
@property
|
||||
def expected_components(self):
|
||||
return []
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
task_input = InputParam("task", type_hint=str, required=False, description="VLM Task to execute")
|
||||
prompt_input = InputParam(
|
||||
"prompt",
|
||||
type_hint=str,
|
||||
required=False,
|
||||
description="Prompt to use",
|
||||
)
|
||||
image_input = InputParam(
|
||||
name="image", type_hint=Image.Image, required=False, description="image for inspiration mode"
|
||||
)
|
||||
json_prompt_input = InputParam(
|
||||
name="json_prompt", type_hint=str, required=False, description="JSON prompt to use"
|
||||
)
|
||||
sampling_top_p_input = InputParam(
|
||||
name="sampling_top_p", type_hint=float, required=False, description="Sampling top p", default=1.0
|
||||
)
|
||||
sampling_temperature_input = InputParam(
|
||||
name="sampling_temperature",
|
||||
type_hint=float,
|
||||
required=False,
|
||||
description="Sampling temperature",
|
||||
default=0.2,
|
||||
)
|
||||
sampling_max_tokens_input = InputParam(
|
||||
name="sampling_max_tokens", type_hint=int, required=False, description="Sampling max tokens", default=3000
|
||||
)
|
||||
return [
|
||||
task_input,
|
||||
prompt_input,
|
||||
image_input,
|
||||
json_prompt_input,
|
||||
sampling_top_p_input,
|
||||
sampling_temperature_input,
|
||||
sampling_max_tokens_input,
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"json_prompt",
|
||||
type_hint=str,
|
||||
description="JSON prompt by the VLM",
|
||||
)
|
||||
]
|
||||
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
prompt = block_state.prompt
|
||||
image = block_state.image
|
||||
json_prompt = block_state.json_prompt
|
||||
client = genai.Client()
|
||||
json_prompt = json_promptify(
|
||||
client=client,
|
||||
model_id=self.model_id,
|
||||
top_p=block_state.sampling_top_p,
|
||||
temperature=block_state.sampling_temperature,
|
||||
max_tokens=block_state.sampling_max_tokens,
|
||||
user_prompt=prompt,
|
||||
existing_json=json_prompt,
|
||||
image=image,
|
||||
)
|
||||
block_state.json_prompt = json_prompt
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
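# Usage sketch (illustrative; the exact modular-pipeline wiring shown here is an assumption, see the
# diffusers modular pipelines docs): the block reads `prompt` (and optionally `image` / `json_prompt`)
# from the pipeline state, calls Gemini, and writes the resulting `json_prompt` back.
#
#   block = BriaFiboGeminiPromptToJson()  # requires GOOGLE_API_KEY in the environment
#   pipe = block.init_pipeline()          # assumed helper that wraps a single block into a runnable pipeline
#   json_prompt = pipe(prompt="a foggy mountain village at dawn", output="json_prompt")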
|
||||
@@ -128,6 +128,7 @@ else:
|
||||
"AnimateDiffVideoToVideoControlNetPipeline",
|
||||
]
|
||||
_import_structure["bria"] = ["BriaPipeline"]
|
||||
_import_structure["bria_fibo"] = ["BriaFiboPipeline"]
|
||||
_import_structure["flux"] = [
|
||||
"FluxControlPipeline",
|
||||
"FluxControlInpaintPipeline",
|
||||
@@ -562,6 +563,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .aura_flow import AuraFlowPipeline
|
||||
from .blip_diffusion import BlipDiffusionPipeline
|
||||
from .bria import BriaPipeline
|
||||
from .bria_fibo import BriaFiboPipeline
|
||||
from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
|
||||
from .cogvideo import (
|
||||
CogVideoXFunControlPipeline,
|
||||
|
||||
48
src/diffusers/pipelines/bria_fibo/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_bria_fibo import BriaFiboPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
826
src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py
Normal file
@@ -0,0 +1,826 @@
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin
|
||||
from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan
|
||||
from ...models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
|
||||
from ...pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput
|
||||
from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
...
|
||||
"""
|
||||
|
||||
|
||||
class BriaFiboPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Args:
|
||||
transformer ([`BriaFiboTransformer2DModel`]):
Conditional Transformer architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
text_encoder ([`SmolLM3ForCausalLM`]):
Text encoder used to embed the structured JSON prompt.
tokenizer (`AutoTokenizer`):
Tokenizer associated with `text_encoder`.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer: BriaFiboTransformer2DModel,
|
||||
scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
|
||||
vae: AutoencoderKLWan,
|
||||
text_encoder: SmolLM3ForCausalLM,
|
||||
tokenizer: AutoTokenizer,
|
||||
):
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
self.vae_scale_factor = 16
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.default_sample_size = 64
|
||||
|
||||
def get_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: int = 1,
|
||||
max_sequence_length: int = 2048,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if not prompt:
|
||||
raise ValueError("`prompt` must be a non-empty string or list of strings.")
|
||||
|
||||
batch_size = len(prompt)
|
||||
bot_token_id = 128000
|
||||
|
||||
text_encoder_device = device if device is not None else torch.device("cpu")
|
||||
if not isinstance(text_encoder_device, torch.device):
|
||||
text_encoder_device = torch.device(text_encoder_device)
|
||||
|
||||
if all(p == "" for p in prompt):
|
||||
input_ids = torch.full((batch_size, 1), bot_token_id, dtype=torch.long, device=text_encoder_device)
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
else:
|
||||
tokenized = self.tokenizer(
|
||||
prompt,
|
||||
padding="longest",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_ids = tokenized.input_ids.to(text_encoder_device)
|
||||
attention_mask = tokenized.attention_mask.to(text_encoder_device)
|
||||
|
||||
if any(p == "" for p in prompt):
|
||||
empty_rows = torch.tensor([p == "" for p in prompt], dtype=torch.bool, device=text_encoder_device)
|
||||
input_ids[empty_rows] = bot_token_id
|
||||
attention_mask[empty_rows] = 1
|
||||
|
||||
encoder_outputs = self.text_encoder(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
hidden_states = encoder_outputs.hidden_states
|
||||
|
||||
prompt_embeds = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1)
|
||||
prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
|
||||
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
hidden_states = tuple(
|
||||
layer.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) for layer in hidden_states
|
||||
)
|
||||
attention_mask = attention_mask.repeat_interleave(num_images_per_prompt, dim=0).to(device=device)
|
||||
|
||||
return prompt_embeds, hidden_states, attention_mask
|
||||
|
||||
@staticmethod
|
||||
def pad_embedding(prompt_embeds, max_tokens, attention_mask=None):
|
||||
# Pad embeddings to `max_tokens` while preserving the mask of real tokens.
|
||||
batch_size, seq_len, dim = prompt_embeds.shape
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones((batch_size, seq_len), dtype=prompt_embeds.dtype, device=prompt_embeds.device)
|
||||
else:
|
||||
attention_mask = attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
|
||||
|
||||
if max_tokens < seq_len:
|
||||
raise ValueError("`max_tokens` must be greater or equal to the current sequence length.")
|
||||
|
||||
if max_tokens > seq_len:
|
||||
pad_length = max_tokens - seq_len
|
||||
padding = torch.zeros(
|
||||
(batch_size, pad_length, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
|
||||
)
|
||||
prompt_embeds = torch.cat([prompt_embeds, padding], dim=1)
|
||||
|
||||
mask_padding = torch.zeros(
|
||||
(batch_size, pad_length), dtype=prompt_embeds.dtype, device=prompt_embeds.device
|
||||
)
|
||||
attention_mask = torch.cat([attention_mask, mask_padding], dim=1)
|
||||
|
||||
return prompt_embeds, attention_mask
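# Shape example (illustrative): padding a (2, 40, 4096) prompt embedding to max_tokens=64 returns a
# (2, 64, 4096) tensor whose last 24 positions are zeros, plus a (2, 64) attention mask that is 1 for
# the original 40 tokens and 0 for the padding.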
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
max_sequence_length: int = 3000,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
prompt_attention_mask = None
|
||||
negative_prompt_attention_mask = None
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
|
||||
prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if isinstance(negative_prompt, list) and negative_prompt[0] is None:
|
||||
negative_prompt = ""
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype)
|
||||
negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers]
|
||||
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
# Pad to longest
|
||||
if prompt_attention_mask is not None:
|
||||
prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
|
||||
|
||||
if negative_prompt_embeds is not None:
|
||||
if negative_prompt_attention_mask is not None:
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.to(
|
||||
device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype
|
||||
)
|
||||
max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1])
|
||||
|
||||
prompt_embeds, prompt_attention_mask = self.pad_embedding(
|
||||
prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
|
||||
)
|
||||
prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers]
|
||||
|
||||
negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding(
|
||||
negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask
|
||||
)
|
||||
negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers]
|
||||
else:
|
||||
max_tokens = prompt_embeds.shape[1]
|
||||
prompt_embeds, prompt_attention_mask = self.pad_embedding(
|
||||
prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
|
||||
)
|
||||
negative_prompt_layers = None
|
||||
|
||||
dtype = self.text_encoder.dtype
|
||||
text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype)
|
||||
|
||||
return (
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
text_ids,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
prompt_layers,
|
||||
negative_prompt_layers,
|
||||
)
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
return self._joint_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@staticmethod
|
||||
def _unpack_latents(latents, height, width, vae_scale_factor):
|
||||
batch_size, num_patches, channels = latents.shape
|
||||
|
||||
height = height // vae_scale_factor
|
||||
width = width // vae_scale_factor
|
||||
|
||||
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
||||
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
||||
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
|
||||
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
def _unpack_latents_no_patch(latents, height, width, vae_scale_factor):
|
||||
batch_size, num_patches, channels = latents.shape
|
||||
|
||||
height = height // vae_scale_factor
|
||||
width = width // vae_scale_factor
|
||||
|
||||
latents = latents.view(batch_size, height, width, channels)
|
||||
latents = latents.permute(0, 3, 1, 2)
|
||||
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.permute(0, 2, 3, 1)
|
||||
latents = latents.reshape(batch_size, height * width, num_channels_latents)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
||||
|
||||
return latents
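# Shape example (illustrative): with patching, a (1, 12, 64, 64) latent packs into (1, 32 * 32, 48) = (1, 1024, 48),
# and _unpack_latents(..., height=1024, width=1024, vae_scale_factor=16) restores (1, 12, 64, 64).
# Without patching, _pack_latents_no_patch turns (1, 48, 64, 64) into (1, 4096, 48).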
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
do_patching=False,
|
||||
):
|
||||
height = int(height) // self.vae_scale_factor
|
||||
width = int(width) // self.vae_scale_factor
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is not None:
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
||||
return latents.to(device=device, dtype=dtype), latent_image_ids
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
if do_patching:
|
||||
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
else:
|
||||
latents = self._pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width)
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
||||
|
||||
return latents, latent_image_ids
|
||||
|
||||
@staticmethod
|
||||
def init_inference_scheduler(height, width, device, image_seq_len, num_inference_steps=1000, noise_scheduler=None):
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
|
||||
assert height % 16 == 0 and width % 16 == 0
|
||||
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
noise_scheduler.config.base_image_seq_len,
|
||||
noise_scheduler.config.max_image_seq_len,
|
||||
noise_scheduler.config.base_shift,
|
||||
noise_scheduler.config.max_shift,
|
||||
)
|
||||
|
||||
# Init sigmas and timesteps according to shift size
|
||||
# This changes the scheduler in-place according to the dynamic scheduling
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
noise_scheduler,
|
||||
num_inference_steps=num_inference_steps,
|
||||
device=device,
|
||||
timesteps=None,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
|
||||
return noise_scheduler, timesteps, num_inference_steps, mu
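# Example (illustrative): for a 1024x1024 image without patching, image_seq_len = 64 * 64 = 4096;
# calculate_shift interpolates `mu` between the scheduler's base_shift and max_shift according to where
# 4096 falls between base_image_seq_len and max_image_seq_len, and retrieve_timesteps then rebuilds the
# sigmas/timesteps in place with that shift.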
|
||||
|
||||
@staticmethod
|
||||
def create_attention_matrix(attention_mask):
|
||||
attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask)
|
||||
|
||||
# convert to an additive attention bias: 0 = keep, -inf = ignore
|
||||
attention_matrix = torch.where(
|
||||
attention_matrix == 1, 0.0, -torch.inf
|
||||
) # Apply -inf to ignored tokens for nulling softmax score
|
||||
return attention_matrix
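# Example (illustrative): for attention_mask = [[1, 1, 0]] the outer product is
# [[1, 1, 0], [1, 1, 0], [0, 0, 0]], which becomes [[0, 0, -inf], [0, 0, -inf], [-inf, -inf, -inf]],
# so padded tokens neither attend nor are attended to.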
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 30,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 3000,
|
||||
do_patching=False,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
|
||||
height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 30):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.bria_fibo.pipeline_output.BriaFiboPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that is called at the end of each denoising step during inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, *optional*, defaults to 3000): Maximum sequence length to use with the `prompt`.
|
||||
Examples:
|
||||
Returns:
|
||||
[`~pipelines.bria_fibo.pipeline_output.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.bria_fibo.pipeline_output.BriaFiboPipelineOutput`] if
|
||||
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
||||
generated images.
|
||||
"""
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
prompt_embeds=prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
lora_scale = (
|
||||
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
||||
)
|
||||
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
text_ids,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
prompt_layers,
|
||||
negative_prompt_layers,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
prompt_batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_layers = [
|
||||
torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers))
|
||||
]
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
|
||||
total_num_layers_transformer = len(self.transformer.transformer_blocks) + len(
|
||||
self.transformer.single_transformer_blocks
|
||||
)
|
||||
if len(prompt_layers) >= total_num_layers_transformer:
|
||||
# drop the earliest hidden states so one remains per transformer block
|
||||
prompt_layers = prompt_layers[len(prompt_layers) - total_num_layers_transformer :]
|
||||
else:
|
||||
# repeat the last hidden state until there is one per transformer block
|
||||
prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers))
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
if do_patching:
|
||||
num_channels_latents = int(num_channels_latents / 4)
|
||||
|
||||
latents, latent_image_ids = self.prepare_latents(
|
||||
prompt_batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
do_patching,
|
||||
)
|
||||
|
||||
latent_attention_mask = torch.ones(
|
||||
[latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
latent_attention_mask = latent_attention_mask.repeat(2, 1)
|
||||
|
||||
attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1)
|
||||
attention_mask = self.create_attention_matrix(attention_mask) # batch, seq => batch, seq, seq
|
||||
attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting
|
||||
|
||||
if self._joint_attention_kwargs is None:
|
||||
self._joint_attention_kwargs = {}
|
||||
self._joint_attention_kwargs["attention_mask"] = attention_mask
|
||||
|
||||
# Adapt scheduler to dynamic shifting (resolution dependent)
|
||||
|
||||
if do_patching:
|
||||
seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2))
|
||||
else:
|
||||
seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor)
|
||||
|
||||
self.noise_scheduler, timesteps, num_inference_steps, mu = self.init_inference_scheduler(
|
||||
height=height,
|
||||
width=width,
|
||||
device=device,
|
||||
num_inference_steps=num_inference_steps,
|
||||
noise_scheduler=self.scheduler,
|
||||
image_seq_len=seq_len,
|
||||
)
|
||||
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# Support older diffusers versions where the ids still carried a batch dimension
|
||||
if len(latent_image_ids.shape) == 3:
|
||||
latent_image_ids = latent_image_ids[0]
|
||||
|
||||
if len(text_ids.shape) == 3:
|
||||
text_ids = text_ids[0]
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
|
||||
if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0]).to(
|
||||
device=latent_model_input.device, dtype=latent_model_input.dtype
|
||||
)
|
||||
|
||||
# This predicts "v" for flow matching or epsilon for diffusion
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
text_encoder_layers=prompt_layers,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
        if output_type == "latent":
            image = latents
        else:
            if do_patching:
                latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            else:
                latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor)

            latents = latents.to(dtype=self.vae.dtype)
            latents = latents.unsqueeze(dim=2)
            latents = list(torch.unbind(latents, dim=0))
            latents_device = latents[0].device
            latents_dtype = latents[0].dtype
            latents_mean = (
                torch.tensor(self.vae.config.latents_mean)
                .view(1, self.vae.config.z_dim, 1, 1, 1)
                .to(latents_device, latents_dtype)
            )
            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
                latents_device, latents_dtype
            )
            latents_scaled = [latent / latents_std + latents_mean for latent in latents]
            latents_scaled = torch.cat(latents_scaled, dim=0)
            image = []
            for scaled_latent in latents_scaled:
                curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0]
                curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type)
                image.append(curr_image)
            if len(image) == 1:
                image = image[0]
            else:
                image = np.stack(image, axis=0)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return BriaFiboPipelineOutput(images=image)

    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        if max_sequence_length is not None and max_sequence_length > 3000:
            raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}")

    def to(self, *args, **kwargs):
        DiffusionPipeline.to(self, *args, **kwargs)
        # Keep the VAE in float32, matching how the Wan 2.2 VAE is used in its reference repository.
        self.vae.to(dtype=torch.float32)
        return self
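
For reviewers, a minimal end-to-end sketch of calling the new pipeline. This is illustrative only: it assumes access to the gated `briaai/FIBO` checkpoint, a CUDA device, and bfloat16 weights, and the sampler settings (steps, resolution) are placeholder values rather than recommended defaults.

```python
import torch

from diffusers import BriaFiboPipeline

# Assumes the gate on briaai/FIBO has been accepted and `hf auth login` has been run.
pipe = BriaFiboPipeline.from_pretrained("briaai/FIBO", torch_dtype=torch.bfloat16)
pipe.to("cuda")  # the pipeline's `to` override keeps the VAE in float32

image = pipe(
    prompt='{"text": "A painting of a squirrel eating a burger"}',  # FIBO expects structured JSON captions
    negative_prompt="blurry, low quality",
    num_inference_steps=30,  # placeholder value
    guidance_scale=5.0,      # matches the value used in the fast tests below
    height=1024,             # placeholder resolution
    width=1024,
).images[0]
image.save("fibo_sample.png")
```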
21
src/diffusers/pipelines/bria_fibo/pipeline_output.py
Normal file
@@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import List, Union

import numpy as np
import PIL.Image

from ...utils import BaseOutput


@dataclass
class BriaFiboPipelineOutput(BaseOutput):
    """
    Output class for BriaFibo pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. The PIL images or numpy array represent the denoised images of the diffusion pipeline.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
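
A quick, illustrative check of how this output dataclass behaves, standalone and with a synthetic PIL image rather than a real pipeline call:

```python
import PIL.Image

from diffusers.pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput

# Fields behave like any other diffusers `BaseOutput`: attribute access plus tuple conversion.
out = BriaFiboPipelineOutput(images=[PIL.Image.new("RGB", (64, 64))])
print(out.images[0].size)   # (64, 64)
print(len(out.to_tuple()))  # 1 -- the single `images` field
```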
132
tests/models/transformers/test_models_transformer_bria_fibo.py
Normal file
@@ -0,0 +1,132 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import BriaFiboTransformer2DModel

from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin


enable_full_determinism()


class BriaFiboTransformerTests(ModelTesterMixin, unittest.TestCase):
    model_class = BriaFiboTransformer2DModel
    main_input_name = "hidden_states"
    # We override the items here because the transformer under consideration is small.
    model_split_percents = [0.8, 0.7, 0.7]

    # Skip setting testing with default: AttnProcessor
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        batch_size = 1
        num_latent_channels = 48
        num_image_channels = 3
        height = width = 16
        sequence_length = 32
        embedding_dim = 64

        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
        text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
        image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "img_ids": image_ids,
            "txt_ids": text_ids,
            "timestep": timestep,
            "text_encoder_layers": [encoder_hidden_states[:, :, :32], encoder_hidden_states[:, :, :32]],
        }

    @property
    def input_shape(self):
        return (16, 16)

    @property
    def output_shape(self):
        return (256, 48)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "patch_size": 1,
            "in_channels": 48,
            "num_layers": 1,
            "num_single_layers": 1,
            "attention_head_dim": 8,
            "num_attention_heads": 2,
            "joint_attention_dim": 64,
            "text_encoder_dim": 32,
            "pooled_projection_dim": None,
            "axes_dims_rope": [0, 4, 4],
        }

        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_deprecated_inputs_img_txt_ids_3d(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output_1 = model(**inputs_dict)[0]

        # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
        text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
        image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)

        assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
        assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"

        inputs_dict["txt_ids"] = text_ids_3d
        inputs_dict["img_ids"] = image_ids_3d

        with torch.no_grad():
            output_2 = model(**inputs_dict)[0]

        self.assertEqual(output_1.shape, output_2.shape)
        self.assertTrue(
            torch.allclose(output_1, output_2, atol=1e-5),
            msg="output with deprecated 3d `img_ids`/`txt_ids` does not match the output with 2d inputs",
        )

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"BriaFiboTransformer2DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class BriaFiboTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    model_class = BriaFiboTransformer2DModel

    def prepare_init_args_and_inputs_for_common(self):
        return BriaFiboTransformerTests().prepare_init_args_and_inputs_for_common()


class BriaFiboTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
    model_class = BriaFiboTransformer2DModel

    def prepare_init_args_and_inputs_for_common(self):
        return BriaFiboTransformerTests().prepare_init_args_and_inputs_for_common()
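
As a companion to the tests above, a hedged sketch of instantiating the same tiny configuration and running one forward pass by hand. It is illustrative only; the shapes mirror `dummy_input`, and a real checkpoint would use much larger dimensions.

```python
import torch

from diffusers import BriaFiboTransformer2DModel

# Tiny configuration copied from the test above -- not a real FIBO config.
model = BriaFiboTransformer2DModel(
    patch_size=1,
    in_channels=48,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=8,
    num_attention_heads=2,
    joint_attention_dim=64,
    text_encoder_dim=32,
    pooled_projection_dim=None,
    axes_dims_rope=[0, 4, 4],
).eval()

encoder_hidden_states = torch.randn(1, 32, 64)  # (batch, sequence, embedding_dim)
with torch.no_grad():
    sample = model(
        hidden_states=torch.randn(1, 16 * 16, 48),    # packed latent tokens
        encoder_hidden_states=encoder_hidden_states,  # text embeddings
        txt_ids=torch.randn(32, 3),
        img_ids=torch.randn(16 * 16, 3),
        timestep=torch.tensor([1.0]),
        text_encoder_layers=[encoder_hidden_states[:, :, :32]] * 2,
    )[0]
print(sample.shape)  # torch.Size([1, 256, 48])
```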
0
tests/pipelines/bria_fibo/__init__.py
Normal file
198
tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py
Normal file
@@ -0,0 +1,198 @@
# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import tempfile
import unittest

import numpy as np
import torch
from transformers import AutoTokenizer
from transformers.models.smollm3.modeling_smollm3 import SmolLM3Config, SmolLM3ForCausalLM

from diffusers import (
    AutoencoderKLWan,
    BriaFiboPipeline,
    FlowMatchEulerDiscreteScheduler,
)
from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
from tests.pipelines.test_pipelines_common import PipelineTesterMixin, to_np

from ...testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    require_torch_accelerator,
    slow,
    torch_device,
)


enable_full_determinism()


class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = BriaFiboPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale"])
    batch_params = frozenset(["prompt"])
    test_xformers_attention = False
    test_layerwise_casting = False
    test_group_offloading = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = BriaFiboTransformer2DModel(
            patch_size=1,
            in_channels=16,
            num_layers=1,
            num_single_layers=1,
            attention_head_dim=8,
            num_attention_heads=2,
            joint_attention_dim=64,
            text_encoder_dim=32,
            pooled_projection_dim=None,
            axes_dims_rope=[0, 4, 4],
        )

        torch.manual_seed(0)
        vae = AutoencoderKLWan(
            base_dim=160,
            decoder_base_dim=256,
            num_res_blocks=2,
            out_channels=12,
            patch_size=2,
            scale_factor_spatial=16,
            scale_factor_temporal=4,
            temperal_downsample=[False, True, True],  # spelling follows the AutoencoderKLWan argument name
            z_dim=16,
        )

        scheduler = FlowMatchEulerDiscreteScheduler()

        torch.manual_seed(0)
        text_encoder = SmolLM3ForCausalLM(SmolLM3Config(hidden_size=32))
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        components = {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "transformer": transformer,
            "vae": vae,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)

        inputs = {
            "prompt": "{'text': 'A painting of a squirrel eating a burger'}",
            "negative_prompt": "bad, ugly",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "height": 32,
            "width": 32,
            "output_type": "np",
        }
        return inputs

    def test_encode_prompt_works_in_isolation(self):
        pass

    def test_bria_fibo_different_prompts(self):
        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe = pipe.to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)
        output_same_prompt = pipe(**inputs).images[0]

        inputs = self.get_dummy_inputs(torch_device)
        inputs["prompt"] = "a different prompt"
        output_different_prompts = pipe(**inputs).images[0]

        max_diff = np.abs(output_same_prompt - output_different_prompts).max()
        assert max_diff > 1e-6

    def test_image_output_shape(self):
        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe = pipe.to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        height_width_pairs = [(32, 32)]
        for height, width in height_width_pairs:
            expected_height = height
            expected_width = width

            inputs.update({"height": height, "width": width})
            image = pipe(**inputs).images[0]
            output_height, output_width, _ = image.shape
            assert (output_height, output_width) == (expected_height, expected_width)

    @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
    @require_torch_accelerator
    def test_save_load_float16(self, expected_max_diff=1e-2):
        components = self.get_dummy_components()
        for name, module in components.items():
            if hasattr(module, "half"):
                components[name] = module.to(torch_device).half()

        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output = pipe(**inputs)[0]

        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir)
            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16)
            for component in pipe_loaded.components.values():
                if hasattr(component, "set_default_attn_processor"):
                    component.set_default_attn_processor()
            pipe_loaded.to(torch_device)
            pipe_loaded.set_progress_bar_config(disable=None)

        for name, component in pipe_loaded.components.items():
            if name == "vae":
                continue
            if hasattr(component, "dtype"):
                self.assertTrue(
                    component.dtype == torch.float16,
                    f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
                )

        inputs = self.get_dummy_inputs(torch_device)
        output_loaded = pipe_loaded(**inputs)[0]
        max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
        self.assertLess(
            max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
        )

    def test_to_dtype(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)

        model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
        self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

    def test_save_load_dduf(self):
        pass