
Merge branch 'main' into unbloat-pipeline-utilities

This commit is contained in:
Sayak Paul
2025-11-01 08:03:55 +05:30
committed by GitHub
56 changed files with 2745 additions and 114 deletions

View File

@@ -323,6 +323,8 @@
title: AllegroTransformer3DModel
- local: api/models/aura_flow_transformer2d
title: AuraFlowTransformer2DModel
- local: api/models/transformer_bria_fibo
title: BriaFiboTransformer2DModel
- local: api/models/bria_transformer
title: BriaTransformer2DModel
- local: api/models/chroma_transformer
@@ -469,6 +471,8 @@
title: BLIP-Diffusion
- local: api/pipelines/bria_3_2
title: Bria 3.2
- local: api/pipelines/bria_fibo
title: Bria Fibo
- local: api/pipelines/chroma
title: Chroma
- local: api/pipelines/cogview3
@@ -525,6 +529,8 @@
title: Kandinsky 2.2
- local: api/pipelines/kandinsky3
title: Kandinsky 3
- local: api/pipelines/kandinsky5
title: Kandinsky 5
- local: api/pipelines/kolors
title: Kolors
- local: api/pipelines/latent_consistency_models

View File

@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
# ChromaTransformer2DModel
A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma)
A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma1-HD)
## ChromaTransformer2DModel

View File

@@ -0,0 +1,19 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# BriaFiboTransformer2DModel
A modified flux Transformer model from [Bria](https://huggingface.co/briaai/FIBO)
## BriaFiboTransformer2DModel
[[autodoc]] BriaFiboTransformer2DModel

View File

@@ -0,0 +1,45 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Bria Fibo
Text-to-image models have mastered imagination, but not control. FIBO changes that.
FIBO is trained on structured JSON captions of up to 1,000+ words and is designed to understand and control different visual parameters such as lighting, composition, color, and camera settings, enabling precise and reproducible outputs.
With only 8 billion parameters, FIBO provides a new level of image quality, prompt adherence, and professional control.
FIBO is trained exclusively on structured prompts and will not work with freeform text prompts.
You can use the [FIBO-VLM-prompt-to-JSON](https://huggingface.co/briaai/FIBO-VLM-prompt-to-JSON) model or the [FIBO-gemini-prompt-to-JSON](https://huggingface.co/briaai/FIBO-gemini-prompt-to-JSON) model to convert a freeform text prompt into a structured JSON prompt.
Using freeform text prompts directly with FIBO is not recommended, as it will not produce the best results.
You can learn more about FIBO on the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO).
## Usage
_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO), fill in the form, and accept the gate. Once you are in, you need to log in so that your system knows you've accepted the gate._
Use the command below to log in:
```bash
hf auth login
```
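Once you are logged in, generation follows the structured-prompt flow described above: convert a freeform prompt into a JSON prompt, then pass that JSON prompt to the pipeline. The sketch below mirrors the example docstring shipped with `BriaFiboPipeline` and assumes access to the gated `briaai/FIBO` checkpoint.
```python
import torch
from diffusers import BriaFiboPipeline
from diffusers.modular_pipelines import ModularPipeline

torch.set_grad_enabled(False)

# Convert a freeform prompt into the structured JSON prompt FIBO expects.
vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)
json_prompt = vlm_pipe(prompt="a beautiful dog").values["json_prompt"]

# Load the gated FIBO checkpoint and generate from the structured JSON prompt.
pipe = BriaFiboPipeline.from_pretrained("briaai/FIBO", trust_remote_code=True, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

image = pipe(prompt=json_prompt, num_inference_steps=50, guidance_scale=5).images[0]
image.save("image_generate.png")
```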
## BriaFiboPipeline
[[autodoc]] BriaFiboPipeline
- all
- __call__

View File

@@ -19,20 +19,21 @@ specific language governing permissions and limitations under the License.
Chroma is a text to image generation model based on Flux.
Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma).
Original model checkpoints for Chroma can be found here:
* High-resolution finetune: [lodestones/Chroma1-HD](https://huggingface.co/lodestones/Chroma1-HD)
* Base model: [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base)
* Original repo with progress checkpoints: [lodestones/Chroma](https://huggingface.co/lodestones/Chroma) (loading this repo with `from_pretrained` will load a Diffusers-compatible version of the `unlocked-v37` checkpoint)
> [!TIP]
> Chroma can use all the same optimizations as Flux.
## Inference
The Diffusers version of Chroma is based on the [`unlocked-v37`](https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors) version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma).
```python
import torch
from diffusers import ChromaPipeline
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
prompt = [
@@ -63,10 +64,10 @@ Then run the following example
import torch
from diffusers import ChromaTransformer2DModel, ChromaPipeline
model_id = "lodestones/Chroma"
model_id = "lodestones/Chroma1-HD"
dtype = torch.bfloat16
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors", torch_dtype=dtype)
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors", torch_dtype=dtype)
pipe = ChromaPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)
pipe.enable_model_cpu_offload()

View File

@@ -0,0 +1,149 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Kandinsky 5.0
Kandinsky 5.0 is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov
Kandinsky 5.0 is a family of diffusion models for Video & Image generation. Kandinsky 5.0 T2V Lite is a lightweight video generation model (2B parameters) that ranks #1 among open-source models in its class. It outperforms larger models and offers the best understanding of Russian concepts in the open-source ecosystem.
The model introduces several key innovations:
- **Latent diffusion pipeline** with **Flow Matching** for improved training stability
- **Diffusion Transformer (DiT)** as the main generative backbone with cross-attention to text embeddings
- Dual text encoding using **Qwen2.5-VL** and **CLIP** for comprehensive text understanding
- **HunyuanVideo 3D VAE** for efficient video encoding and decoding
- **Sparse attention mechanisms** (NABLA) for efficient long-sequence processing
The original codebase can be found at [ai-forever/Kandinsky-5](https://github.com/ai-forever/Kandinsky-5).
> [!TIP]
> Check out the [AI Forever](https://huggingface.co/ai-forever) organization on the Hub for the official model checkpoints for text-to-video generation, including pretrained, SFT, no-CFG, and distilled variants.
## Available Models
Kandinsky 5.0 T2V Lite comes in several variants optimized for different use cases:
| model_id | Description | Use Cases |
|------------|-------------|-----------|
| **ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers** | 5 second Supervised Fine-Tuned model | Highest generation quality |
| **ai-forever/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers** | 10 second Supervised Fine-Tuned model | Highest generation quality |
| **ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers** | 5 second Classifier-Free Guidance distilled | 2× faster inference |
| **ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-10s-Diffusers** | 10 second Classifier-Free Guidance distilled | 2× faster inference |
| **ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers** | 5 second Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss |
| **ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-10s-Diffusers** | 10 second Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss |
| **ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-5s-Diffusers** | 5 second Base pretrained model | Research and fine-tuning |
| **ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-10s-Diffusers** | 10 second Base pretrained model | Research and fine-tuning |
All models are available in 5-second and 10-second video generation versions.
## Kandinsky5T2VPipeline
[[autodoc]] Kandinsky5T2VPipeline
- all
- __call__
## Usage Examples
### Basic Text-to-Video Generation
```python
import torch
from diffusers import Kandinsky5T2VPipeline
from diffusers.utils import export_to_video
# Load the pipeline
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers"
pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
# Generate video
prompt = "A cat and a dog baking a cake together in a kitchen."
negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=512,
width=768,
num_frames=121, # ~5 seconds at 24fps
num_inference_steps=50,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=9)
```
### 10 second Models
**⚠️ Warning!** All 10-second models should be used with the Flex attention backend and `max-autotune-no-cudagraphs` compilation:
```python
pipe = Kandinsky5T2VPipeline.from_pretrained(
"ai-forever/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers",
torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")
pipe.transformer.set_attention_backend(
"flex"
) # <--- Set attention backend to Flex
pipe.transformer.compile(
mode="max-autotune-no-cudagraphs",
dynamic=True
) # <--- Compile with max-autotune-no-cudagraphs
prompt = "A cat and a dog baking a cake together in a kitchen."
negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=512,
width=768,
num_frames=241,
num_inference_steps=50,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=9)
```
### Diffusion Distilled model
**⚠️ Warning!** All no-CFG and diffusion-distilled models should be run without CFG (`guidance_scale=1.0`):
```python
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers"
pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
output = pipe(
prompt="A beautiful sunset over mountains",
num_inference_steps=16, # <--- Model is distilled in 16 steps
guidance_scale=1.0, # <--- no CFG
).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=9)
```
## Citation
```bibtex
@misc{kandinsky2025,
author = {Alexey Letunovskiy and Maria Kovaleva and Ivan Kirillov and Lev Novitskiy and Denis Koposov and
Dmitrii Mikhailov and Anna Averchenkova and Andrey Shutkin and Julia Agafonova and Olga Kim and
Anastasiia Kargapoltseva and Nikita Kiselev and Vladimir Arkhipkin and Vladimir Korviakov and
Nikolai Gerasimenko and Denis Parkhomenko and Anna Dmitrienko and Anastasia Maltseva and
Kirill Chernyshev and Ilia Vasiliev and Viacheslav Vasilev and Vladimir Polovnikov and
Yury Kolabushin and Alexander Belykh and Mikhail Mamaev and Anastasia Aliaskina and
Tatiana Nikulina and Polina Gavrilova and Denis Dimitrov},
title = {Kandinsky 5.0: A family of diffusion models for Video & Image generation},
howpublished = {\url{https://github.com/ai-forever/Kandinsky-5}},
year = 2025
}
```

View File

@@ -21,6 +21,7 @@ Refer to the table below for an overview of the available attention families and
| attention family | main feature |
|---|---|
| FlashAttention | minimizes memory reads/writes through tiling and recomputation |
| AI Tensor Engine for ROCm | FlashAttention implementation optimized for AMD ROCm accelerators |
| SageAttention | quantizes attention to int8 |
| PyTorch native | built-in PyTorch implementation using [scaled_dot_product_attention](./fp16#scaled-dot-product-attention) |
| xFormers | memory-efficient attention with support for various attention kernels |
@@ -139,6 +140,7 @@ Refer to the table below for a complete list of available attention backends and
| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
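On an AMD GPU with a ROCm build of PyTorch and `aiter>=0.1.5` installed, the new backend is selected the same way as any other backend, via `set_attention_backend` on the model. The snippet below is a minimal sketch assuming the gated `black-forest-labs/FLUX.1-dev` checkpoint; any model that supports attention backends works the same way.
```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

# Route the transformer's attention through the AI Tensor Engine for ROCm.
pipe.transformer.set_attention_backend("aiter")

image = pipe("a photo of a red panda", num_inference_steps=28).images[0]
image.save("flux_aiter.png")
```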

View File

@@ -198,6 +198,7 @@ else:
"AutoencoderOobleck",
"AutoencoderTiny",
"AutoModel",
"BriaFiboTransformer2DModel",
"BriaTransformer2DModel",
"CacheMixin",
"ChromaTransformer2DModel",
@@ -430,6 +431,7 @@ else:
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"BriaFiboPipeline",
"BriaPipeline",
"ChromaImg2ImgPipeline",
"ChromaPipeline",
@@ -901,6 +903,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderOobleck,
AutoencoderTiny,
AutoModel,
BriaFiboTransformer2DModel,
BriaTransformer2DModel,
CacheMixin,
ChromaTransformer2DModel,
@@ -1103,6 +1106,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
AuraFlowPipeline,
BriaFiboPipeline,
BriaPipeline,
ChromaImg2ImgPipeline,
ChromaPipeline,

View File

@@ -17,12 +17,6 @@ from typing import Union
from ..utils import is_torch_available, logging
logger = logging.get_logger(__name__)
logger.warning(
"Guiders are currently an experimental feature under active development. The API is subject to breaking changes in future releases."
)
if is_torch_available():
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
from .adaptive_projected_guidance_mix import AdaptiveProjectedMixGuidance

View File

@@ -41,6 +41,10 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
_identifier_key = "__guidance_identifier__"
def __init__(self, start: float = 0.0, stop: float = 1.0, enabled: bool = True):
logger.warning(
"Guiders are currently an experimental feature under active development. The API is subject to breaking changes in future releases."
)
self._start = start
self._stop = stop
self._step: int = None

View File

@@ -1977,14 +1977,34 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
"time_projection.1.diff_b"
)
if any("head.head" in k for k in state_dict):
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
f"head.head.{lora_down_key}.weight"
)
converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight")
if any("head.head" in k for k in original_state_dict):
if any(f"head.head.{lora_down_key}.weight" in k for k in state_dict):
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
f"head.head.{lora_down_key}.weight"
)
if any(f"head.head.{lora_up_key}.weight" in k for k in state_dict):
converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(
f"head.head.{lora_up_key}.weight"
)
if "head.head.diff_b" in original_state_dict:
converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b")
# Notes: https://huggingface.co/lightx2v/Wan2.2-Distill-Loras
# This is my (sayakpaul) assumption that this particular key belongs to the down matrix.
# Since for this particular LoRA, we don't have the corresponding up matrix, I will use
# an identity.
if any("head.head" in k and k.endswith(".diff") for k in state_dict):
if f"head.head.{lora_down_key}.weight" in state_dict:
logger.info(
f"The state dict seems to be have both `head.head.diff` and `head.head.{lora_down_key}.weight` keys, which is unexpected."
)
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop("head.head.diff")
down_matrix_head = converted_state_dict["proj_out.lora_A.weight"]
up_matrix_shape = (down_matrix_head.shape[0], converted_state_dict["proj_out.lora_B.bias"].shape[0])
converted_state_dict["proj_out.lora_B.weight"] = torch.eye(
*up_matrix_shape, dtype=down_matrix_head.dtype, device=down_matrix_head.device
).T
for text_time in ["text_embedding", "time_embedding"]:
if any(text_time in k for k in original_state_dict):
for b_n in [0, 2]:

View File

@@ -84,6 +84,7 @@ if is_torch_available():
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
_import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"]
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
@@ -174,6 +175,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .transformers import (
AllegroTransformer3DModel,
AuraFlowTransformer2DModel,
BriaFiboTransformer2DModel,
BriaTransformer2DModel,
ChromaTransformer2DModel,
CogVideoXTransformer3DModel,

View File

@@ -27,6 +27,8 @@ if torch.distributed.is_available():
from ..utils import (
get_logger,
is_aiter_available,
is_aiter_version,
is_flash_attn_3_available,
is_flash_attn_available,
is_flash_attn_version,
@@ -47,6 +49,7 @@ if TYPE_CHECKING:
from ._modeling_parallel import ParallelConfig
_REQUIRED_FLASH_VERSION = "2.6.3"
_REQUIRED_AITER_VERSION = "0.1.5"
_REQUIRED_SAGE_VERSION = "2.1.1"
_REQUIRED_FLEX_VERSION = "2.5.0"
_REQUIRED_XLA_VERSION = "2.2"
@@ -54,6 +57,7 @@ _REQUIRED_XFORMERS_VERSION = "0.0.29"
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
_CAN_USE_NPU_ATTN = is_torch_npu_available()
@@ -78,6 +82,12 @@ else:
flash_attn_3_func = None
flash_attn_3_varlen_func = None
if _CAN_USE_AITER_ATTN:
from aiter import flash_attn_func as aiter_flash_attn_func
else:
aiter_flash_attn_func = None
if DIFFUSERS_ENABLE_HUB_KERNELS:
if not is_kernels_available():
raise ImportError(
@@ -178,6 +188,9 @@ class AttentionBackendName(str, Enum):
_FLASH_3_HUB = "_flash_3_hub"
# _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet.
# `aiter`
AITER = "aiter"
# PyTorch native
FLEX = "flex"
NATIVE = "native"
@@ -414,6 +427,12 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)
elif backend == AttentionBackendName.AITER:
if not _CAN_USE_AITER_ATTN:
raise RuntimeError(
f"Aiter Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`."
)
elif backend in [
AttentionBackendName.SAGE,
AttentionBackendName.SAGE_VARLEN,
@@ -1397,6 +1416,47 @@ def _flash_varlen_attention_3(
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
AttentionBackendName.AITER,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _aiter_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if not return_lse and torch.is_grad_enabled():
# aiter requires return_lse=True by assertion when gradients are enabled.
out, lse, *_ = aiter_flash_attn_func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_lse=True,
)
else:
out = aiter_flash_attn_func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_lse=return_lse,
)
if return_lse:
out, lse, *_ = out
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
AttentionBackendName.FLEX,
constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],

View File

@@ -1337,9 +1337,18 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
tile_sample_stride_height = self.tile_sample_stride_height
tile_sample_stride_width = self.tile_sample_stride_width
if self.config.patch_size is not None:
sample_height = sample_height // self.config.patch_size
sample_width = sample_width // self.config.patch_size
tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size
tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size
blend_height = self.tile_sample_min_height // self.config.patch_size - tile_sample_stride_height
blend_width = self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width
else:
blend_height = self.tile_sample_min_height - tile_sample_stride_height
blend_width = self.tile_sample_min_width - tile_sample_stride_width
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
@@ -1353,7 +1362,9 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
self._conv_idx = [0]
tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
decoded = self.decoder(
tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0)
)
time.append(decoded)
row.append(torch.cat(time, dim=2))
rows.append(row)
@@ -1369,11 +1380,15 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
if self.config.patch_size is not None:
dec = unpatchify(dec, patch_size=self.config.patch_size)
dec = torch.clamp(dec, min=-1.0, max=1.0)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)

View File

@@ -286,11 +286,9 @@ class Decoder(nn.Module):
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
if torch.is_grad_enabled() and self.gradient_checkpointing:
# middle
sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
@@ -298,7 +296,6 @@ class Decoder(nn.Module):
else:
# middle
sample = self.mid_block(sample, latent_embeds)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:

View File

@@ -18,6 +18,7 @@ if is_torch_available():
from .transformer_2d import Transformer2DModel
from .transformer_allegro import AllegroTransformer3DModel
from .transformer_bria import BriaTransformer2DModel
from .transformer_bria_fibo import BriaFiboTransformer2DModel
from .transformer_chroma import ChromaTransformer2DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_cogview4 import CogView4Transformer2DModel

View File

@@ -0,0 +1,655 @@
# Copyright (c) Bria.ai. All rights reserved.
#
# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
#
# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
# indicate if changes were made, and do not use the material for commercial purposes.
#
# See the license for further details.
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention_processor import Attention
from ...models.embeddings import TimestepEmbedding, apply_rotary_emb, get_1d_rotary_pos_embed, get_timestep_embedding
from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...models.transformers.transformer_bria import BriaAttnProcessor
from ...utils import (
USE_PEFT_BACKEND,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _get_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
encoder_query = encoder_key = encoder_value = None
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
return query, key, value, encoder_query, encoder_key, encoder_value
def _get_fused_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
encoder_query = encoder_key = encoder_value = (None,)
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
return query, key, value, encoder_query, encoder_key, encoder_value
def _get_qkv_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
if attn.fused_projections:
return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
return _get_projections(attn, hidden_states, encoder_hidden_states)
# Copied from diffusers.models.transformers.transformer_flux.FluxAttnProcessor with FluxAttnProcessor->BriaFiboAttnProcessor, FluxAttention->BriaFiboAttention
class BriaFiboAttnProcessor:
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
def __call__(
self,
attn: "BriaFiboAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
attn, hidden_states, encoder_hidden_states
)
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
query = attn.norm_q(query)
key = attn.norm_k(key)
if attn.added_kv_proj_dim is not None:
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
encoder_query = attn.norm_added_q(encoder_query)
encoder_key = attn.norm_added_k(encoder_key)
query = torch.cat([encoder_query, query], dim=1)
key = torch.cat([encoder_key, key], dim=1)
value = torch.cat([encoder_value, value], dim=1)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
# Based on https://github.com/huggingface/diffusers/blob/55d49d4379007740af20629bb61aba9546c6b053/src/diffusers/models/transformers/transformer_flux.py
class BriaFiboAttention(torch.nn.Module, AttentionModuleMixin):
_default_processor_cls = BriaFiboAttnProcessor
_available_processors = [BriaFiboAttnProcessor]
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
out_bias: bool = True,
eps: float = 1e-5,
out_dim: int = None,
context_pre_only: Optional[bool] = None,
pre_only: bool = False,
elementwise_affine: bool = True,
processor=None,
):
super().__init__()
self.head_dim = dim_head
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.use_bias = bias
self.dropout = dropout
self.out_dim = out_dim if out_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.heads = out_dim // dim_head if out_dim is not None else heads
self.added_kv_proj_dim = added_kv_proj_dim
self.added_proj_bias = added_proj_bias
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
if not self.pre_only:
self.to_out = torch.nn.ModuleList([])
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(torch.nn.Dropout(dropout))
if added_kv_proj_dim is not None:
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
if len(unused_kwargs) > 0:
logger.warning(
f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
class BriaFiboEmbedND(torch.nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
@maybe_allow_in_graph
class BriaFiboSingleTransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)
self.norm = AdaLayerNormZeroSingle(dim)
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
processor = BriaAttnProcessor()
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
qk_norm="rms_norm",
eps=1e-6,
pre_only=True,
)
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
joint_attention_kwargs = joint_attention_kwargs or {}
attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
gate = gate.unsqueeze(1)
hidden_states = gate * self.proj_out(hidden_states)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return hidden_states
class BriaFiboTextProjection(nn.Module):
def __init__(self, in_features, hidden_size):
super().__init__()
self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)
def forward(self, caption):
hidden_states = self.linear(caption)
return hidden_states
@maybe_allow_in_graph
# Based on from diffusers.models.transformers.transformer_flux.FluxTransformerBlock
class BriaFiboTransformerBlock(nn.Module):
def __init__(
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
):
super().__init__()
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
self.attn = BriaFiboAttention(
query_dim=dim,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
processor=BriaFiboAttnProcessor(),
eps=eps,
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
joint_attention_kwargs = joint_attention_kwargs or {}
# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
if len(attention_outputs) == 2:
attn_output, context_attn_output = attention_outputs
elif len(attention_outputs) == 3:
attn_output, context_attn_output, ip_attn_output = attention_outputs
# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
if len(attention_outputs) == 3:
hidden_states = hidden_states + ip_attn_output
# Process attention outputs for the `encoder_hidden_states`.
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return encoder_hidden_states, hidden_states
class BriaFiboTimesteps(nn.Module):
def __init__(
self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
self.time_theta = time_theta
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
max_period=self.time_theta,
)
return t_emb
class BriaFiboTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, time_theta):
super().__init__()
self.time_proj = BriaFiboTimesteps(
num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timestep, dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D)
return timesteps_emb
class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
"""
Parameters:
patch_size (`int`): Patch size to turn the input data into small patches.
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
...
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = None,
guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56],
rope_theta=10000,
time_theta=10000,
text_encoder_dim: int = 2048,
):
super().__init__()
self.out_channels = in_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.pos_embed = BriaFiboEmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
self.time_embed = BriaFiboTimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
if guidance_embeds:
self.guidance_embed = BriaFiboTimestepProjEmbeddings(embedding_dim=self.inner_dim)
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
BriaFiboTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
BriaFiboSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_single_layers)
]
)
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
caption_projection = [
BriaFiboTextProjection(in_features=text_encoder_dim, hidden_size=self.inner_dim // 2)
for i in range(self.config.num_layers + self.config.num_single_layers)
]
self.caption_projection = nn.ModuleList(caption_projection)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
text_encoder_layers: list = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype)
if guidance is not None:
guidance = guidance.to(hidden_states.dtype)
else:
guidance = None
temb = self.time_embed(timestep, dtype=hidden_states.dtype)
if guidance:
temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if len(txt_ids.shape) == 3:
txt_ids = txt_ids[0]
if len(img_ids.shape) == 3:
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
new_text_encoder_layers = []
for i, text_encoder_layer in enumerate(text_encoder_layers):
text_encoder_layer = self.caption_projection[i](text_encoder_layer)
new_text_encoder_layers.append(text_encoder_layer)
text_encoder_layers = new_text_encoder_layers
block_id = 0
for index_block, block in enumerate(self.transformer_blocks):
current_text_encoder_layer = text_encoder_layers[block_id]
encoder_hidden_states = torch.cat(
[encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1
)
block_id += 1
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
joint_attention_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
for index_block, block in enumerate(self.single_transformer_blocks):
current_text_encoder_layer = text_encoder_layers[block_id]
encoder_hidden_states = torch.cat(
[encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1
)
block_id += 1
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
temb,
image_rotary_emb,
joint_attention_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
encoder_hidden_states = hidden_states[:, : encoder_hidden_states.shape[1], ...]
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -379,7 +379,7 @@ class ChromaTransformer2DModel(
"""
The Transformer model introduced in Flux, modified for Chroma.
Reference: https://huggingface.co/lodestones/Chroma
Reference: https://huggingface.co/lodestones/Chroma1-HD
Args:
patch_size (`int`, defaults to `1`):

View File

@@ -22,7 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -717,7 +717,11 @@ class FluxTransformer2DModel(
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
if is_torch_npu_available():
freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
else:
image_rotary_emb = self.pos_embed(ids)
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")

View File

@@ -324,6 +324,7 @@ class Kandinsky5AttnProcessor:
sparse_params["sta_mask"],
thr=sparse_params["P"],
)
else:
attn_mask = None
@@ -335,6 +336,7 @@ class Kandinsky5AttnProcessor:
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(-2, -1)
attn_out = attn.out_layer(hidden_states)

View File

@@ -305,6 +305,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
"cache_dir",
"force_download",
"local_files_only",
"local_dir",
"proxies",
"resume_download",
"revision",
@@ -331,11 +332,10 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
module_file=module_file,
class_name=class_name,
**hub_kwargs,
**kwargs,
)
expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls)
block_kwargs = {
name: kwargs.pop(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
name: kwargs.get(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
}
return block_cls(**block_kwargs)

View File

@@ -128,6 +128,7 @@ else:
"AnimateDiffVideoToVideoControlNetPipeline",
]
_import_structure["bria"] = ["BriaPipeline"]
_import_structure["bria_fibo"] = ["BriaFiboPipeline"]
_import_structure["flux"] = [
"FluxControlPipeline",
"FluxControlInpaintPipeline",
@@ -562,6 +563,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .bria import BriaPipeline
from .bria_fibo import BriaFiboPipeline
from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
from .cogvideo import (
CogVideoXFunControlPipeline,

View File

@@ -0,0 +1,48 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_bria_fibo import BriaFiboPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,838 @@
# Copyright (c) Bria.ai. All rights reserved.
#
# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
#
# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
# indicate if changes were made, and do not use the material for commercial purposes.
#
# See the license for further details.
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import AutoTokenizer
from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM
from ...image_processor import VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin
from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan
from ...models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
from ...pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput
from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Example:
```python
import torch
from diffusers import BriaFiboPipeline
from diffusers.modular_pipelines import ModularPipeline
torch.set_grad_enabled(False)
vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)
pipe = BriaFiboPipeline.from_pretrained(
"briaai/FIBO",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
with torch.inference_mode():
# 1. Create a prompt to generate an initial image
output = vlm_pipe(prompt="a beautiful dog")
json_prompt_generate = output.values["json_prompt"]
# Generate the image from the structured json prompt
results_generate = pipe(prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=5)
results_generate.images[0].save("image_generate.png")
```
"""
class BriaFiboPipeline(DiffusionPipeline):
r"""
Args:
transformer (`BriaFiboTransformer2DModel`):
The transformer model for 2D diffusion modeling.
scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`):
Scheduler to be used with `transformer` to denoise the encoded latents.
vae (`AutoencoderKLWan`):
Variational Auto-Encoder for encoding and decoding images to and from latent representations.
text_encoder (`SmolLM3ForCausalLM`):
Text encoder for processing input prompts.
tokenizer (`AutoTokenizer`):
Tokenizer used for processing the input text prompts for the text_encoder.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
transformer: BriaFiboTransformer2DModel,
scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
vae: AutoencoderKLWan,
text_encoder: SmolLM3ForCausalLM,
tokenizer: AutoTokenizer,
):
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.default_sample_size = 64
def get_prompt_embeds(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
max_sequence_length: int = 2048,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
if not prompt:
raise ValueError("`prompt` must be a non-empty string or list of strings.")
batch_size = len(prompt)
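# `bot_token_id` is the tokenizer's begin-of-text token; it stands in for empty prompts below.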
bot_token_id = 128000
text_encoder_device = device if device is not None else torch.device("cpu")
if not isinstance(text_encoder_device, torch.device):
text_encoder_device = torch.device(text_encoder_device)
if all(p == "" for p in prompt):
input_ids = torch.full((batch_size, 1), bot_token_id, dtype=torch.long, device=text_encoder_device)
attention_mask = torch.ones_like(input_ids)
else:
tokenized = self.tokenizer(
prompt,
padding="longest",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
input_ids = tokenized.input_ids.to(text_encoder_device)
attention_mask = tokenized.attention_mask.to(text_encoder_device)
if any(p == "" for p in prompt):
empty_rows = torch.tensor([p == "" for p in prompt], dtype=torch.bool, device=text_encoder_device)
input_ids[empty_rows] = bot_token_id
attention_mask[empty_rows] = 1
encoder_outputs = self.text_encoder(
input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
)
hidden_states = encoder_outputs.hidden_states
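# The prompt embedding concatenates the last two hidden-state layers along the feature dimension.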
prompt_embeds = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1)
prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
hidden_states = tuple(
layer.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) for layer in hidden_states
)
attention_mask = attention_mask.repeat_interleave(num_images_per_prompt, dim=0).to(device=device)
return prompt_embeds, hidden_states, attention_mask
@staticmethod
def pad_embedding(prompt_embeds, max_tokens, attention_mask=None):
# Pad embeddings to `max_tokens` while preserving the mask of real tokens.
batch_size, seq_len, dim = prompt_embeds.shape
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_len), dtype=prompt_embeds.dtype, device=prompt_embeds.device)
else:
attention_mask = attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
if max_tokens < seq_len:
raise ValueError("`max_tokens` must be greater or equal to the current sequence length.")
if max_tokens > seq_len:
pad_length = max_tokens - seq_len
padding = torch.zeros(
(batch_size, pad_length, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
)
prompt_embeds = torch.cat([prompt_embeds, padding], dim=1)
mask_padding = torch.zeros(
(batch_size, pad_length), dtype=prompt_embeds.dtype, device=prompt_embeds.device
)
attention_mask = torch.cat([attention_mask, mask_padding], dim=1)
return prompt_embeds, attention_mask
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
guidance_scale: float = 5,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 3000,
lora_scale: Optional[float] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
guidance_scale (`float`):
Guidance scale for classifier free guidance.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
device = device or self._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
prompt_attention_mask = None
negative_prompt_attention_mask = None
if prompt_embeds is None:
prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers]
if guidance_scale > 1:
if isinstance(negative_prompt, list) and negative_prompt[0] is None:
negative_prompt = ""
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds(
prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype)
negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers]
if self.text_encoder is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
# Pad to longest
if prompt_attention_mask is not None:
prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
if negative_prompt_embeds is not None:
if negative_prompt_attention_mask is not None:
negative_prompt_attention_mask = negative_prompt_attention_mask.to(
device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype
)
max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1])
prompt_embeds, prompt_attention_mask = self.pad_embedding(
prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
)
prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers]
negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding(
negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask
)
negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers]
else:
max_tokens = prompt_embeds.shape[1]
prompt_embeds, prompt_attention_mask = self.pad_embedding(
prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
)
negative_prompt_layers = None
dtype = self.text_encoder.dtype
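# Text token position ids are all zeros, following the Flux convention.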
text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype)
return (
prompt_embeds,
negative_prompt_embeds,
text_ids,
prompt_attention_mask,
negative_prompt_attention_mask,
prompt_layers,
negative_prompt_layers,
)
@property
def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier-free guidance.
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@staticmethod
# Based on diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor
width = width // vae_scale_factor
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
return latents
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod
def _unpack_latents_no_patch(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor
width = width // vae_scale_factor
latents = latents.view(batch_size, height, width, channels)
latents = latents.permute(0, 3, 1, 2)
return latents
@staticmethod
def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width):
latents = latents.permute(0, 2, 3, 1)
latents = latents.reshape(batch_size, height * width, num_channels_latents)
return latents
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
do_patching=False,
):
height = int(height) // self.vae_scale_factor
width = int(width) // self.vae_scale_factor
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
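# Pack latents into transformer tokens: 2x2 patches when `do_patching` is enabled, one token per latent pixel otherwise.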
if do_patching:
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
else:
latents = self._pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
return latents, latent_image_ids
@staticmethod
def _prepare_attention_mask(attention_mask):
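# Outer product of the 1D token masks gives a [batch, seq, seq] matrix that is 1 only where both query and key tokens are real.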
attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask)
# convert to 0 - keep, -inf ignore
attention_matrix = torch.where(
attention_matrix == 1, 0.0, -torch.inf
) # Apply -inf to ignored tokens for nulling softmax score
return attention_matrix
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 30,
timesteps: List[int] = None,
guidance_scale: float = 5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 3000,
do_patching=False,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 30):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.bria_fibo.BriaFiboPipelineOutput`] instead
of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, *optional*, defaults to 3000): Maximum sequence length to use with the `prompt`.
do_patching (`bool`, *optional*, defaults to `False`): Whether to pack the latents into 2x2 patches before
passing them to the transformer.
Examples:
Returns:
[`~pipelines.bria_fibo.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.bria_fibo.BriaFiboPipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
height=height,
width=width,
prompt_embeds=prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
(
prompt_embeds,
negative_prompt_embeds,
text_ids,
prompt_attention_mask,
negative_prompt_attention_mask,
prompt_layers,
negative_prompt_layers,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
guidance_scale=guidance_scale,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
max_sequence_length=max_sequence_length,
num_images_per_prompt=num_images_per_prompt,
lora_scale=lora_scale,
)
prompt_batch_size = prompt_embeds.shape[0]
if guidance_scale > 1:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_layers = [
torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers))
]
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
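# Align the number of text-encoder hidden-state layers with the total number of transformer blocks (dual-stream + single-stream).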
total_num_layers_transformer = len(self.transformer.transformer_blocks) + len(
self.transformer.single_transformer_blocks
)
if len(prompt_layers) >= total_num_layers_transformer:
# remove first layers
prompt_layers = prompt_layers[len(prompt_layers) - total_num_layers_transformer :]
else:
# duplicate last layer
prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers))
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
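# With patching, a 2x2 patch of latent pixels is folded into the channel dimension, so the per-pixel channel count is `in_channels / 4`.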
if do_patching:
num_channels_latents = int(num_channels_latents / 4)
latents, latent_image_ids = self.prepare_latents(
prompt_batch_size,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
do_patching,
)
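# Latent tokens are always attended to; build a joint attention mask over text and image tokens.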
latent_attention_mask = torch.ones(
[latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device
)
if guidance_scale > 1:
latent_attention_mask = latent_attention_mask.repeat(2, 1)
attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1)
attention_mask = self._prepare_attention_mask(attention_mask) # batch, seq => batch, seq, seq
attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting
if self._joint_attention_kwargs is None:
self._joint_attention_kwargs = {}
self._joint_attention_kwargs["attention_mask"] = attention_mask
# Adapt scheduler to dynamic shifting (resolution dependent)
if do_patching:
seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2))
else:
seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor)
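# Resolution-dependent timestep shifting: `mu` is derived from the number of image tokens.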
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
mu = calculate_shift(
seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
# Init sigmas and timesteps according to shift size
# This changes the scheduler in-place according to the dynamic scheduling
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps=num_inference_steps,
device=device,
timesteps=None,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# Support older diffusers versions where the position ids may include a batch dimension
if len(latent_image_ids.shape) == 3:
latent_image_ids = latent_image_ids[0]
if len(text_ids.shape) == 3:
text_ids = text_ids[0]
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).to(
device=latent_model_input.device, dtype=latent_model_input.dtype
)
# This predicts "v" from flow matching or eps from diffusion
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
text_encoder_layers=prompt_layers,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
txt_ids=text_ids,
img_ids=latent_image_ids,
)[0]
# perform guidance
if guidance_scale > 1:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
image = latents
else:
if do_patching:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
else:
latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor)
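# The Wan VAE is a video autoencoder, so add a singleton frame dimension before decoding.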
latents = latents.unsqueeze(dim=2)
latents_device = latents[0].device
latents_dtype = latents[0].dtype
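# De-normalize the latents with the per-channel statistics stored in the VAE config before decoding.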
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents_device, latents_dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents_device, latents_dtype
)
latents_scaled = [latent / latents_std + latents_mean for latent in latents]
latents_scaled = torch.cat(latents_scaled, dim=0)
image = []
for scaled_latent in latents_scaled:
curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0]
curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type)
image.append(curr_image)
if len(image) == 1:
image = image[0]
else:
image = np.stack(image, axis=0)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return BriaFiboPipelineOutput(images=image)
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if max_sequence_length is not None and max_sequence_length > 3000:
raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}")

View File

@@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from ...utils import BaseOutput
@dataclass
class BriaFiboPipelineOutput(BaseOutput):
"""
Output class for BriaFibo pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`):
List of denoised PIL images of length `batch_size` or a numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy arrays represent the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]

View File

@@ -52,8 +52,8 @@ EXAMPLE_DOC_STRING = """
>>> import torch
>>> from diffusers import ChromaPipeline
>>> model_id = "lodestones/Chroma"
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
>>> model_id = "lodestones/Chroma1-HD"
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
>>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
>>> pipe = ChromaPipeline.from_pretrained(
... model_id,
@@ -83,7 +83,7 @@ class ChromaPipeline(
r"""
The Chroma pipeline for text-to-image generation.
Reference: https://huggingface.co/lodestones/Chroma/
Reference: https://huggingface.co/lodestones/Chroma1-HD/
Args:
transformer ([`ChromaTransformer2DModel`]):
@@ -158,20 +158,23 @@ class ChromaPipeline(
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask.clone()
tokenizer_mask = text_inputs.attention_mask
# Chroma requires the attention mask to include one padding token
seq_lengths = attention_mask.sum(dim=1)
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()
tokenizer_mask_device = tokenizer_mask.to(device)
# unlike FLUX, Chroma uses the attention mask when generating the T5 embedding
prompt_embeds = self.text_encoder(
text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
text_input_ids.to(device),
output_hidden_states=False,
attention_mask=tokenizer_mask_device,
)[0]
dtype = self.text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
attention_mask = attention_mask.to(device=device)
# for the text tokens, chroma requires that all except the first padding token are masked out during the forward pass through the transformer
seq_lengths = tokenizer_mask_device.sum(dim=1)
mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape

View File

@@ -52,8 +52,8 @@ EXAMPLE_DOC_STRING = """
>>> import torch
>>> from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline
>>> model_id = "lodestones/Chroma"
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
>>> model_id = "lodestones/Chroma1-HD"
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
>>> pipe = ChromaImg2ImgPipeline.from_pretrained(
... model_id,
... transformer=transformer,
@@ -81,7 +81,7 @@ class ChromaImg2ImgPipeline(
r"""
The Chroma pipeline for image-to-image generation.
Reference: https://huggingface.co/lodestones/Chroma/
Reference: https://huggingface.co/lodestones/Chroma1-HD/
Args:
transformer ([`ChromaTransformer2DModel`]):
@@ -158,20 +158,21 @@ class ChromaImg2ImgPipeline(
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask.clone()
tokenizer_mask = text_inputs.attention_mask
# Chroma requires the attention mask to include one padding token
seq_lengths = attention_mask.sum(dim=1)
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
tokenizer_mask_device = tokenizer_mask.to(device)
prompt_embeds = self.text_encoder(
text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
text_input_ids.to(device),
output_hidden_states=False,
attention_mask=tokenizer_mask_device,
)[0]
dtype = self.text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
attention_mask = attention_mask.to(dtype=dtype, device=device)
seq_lengths = tokenizer_mask_device.sum(dim=1)
mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape

View File

@@ -205,7 +205,7 @@ class StableDiffusion3ControlNetPipeline(
return torch.zeros(
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -294,7 +294,7 @@ class StableDiffusion3ControlNetPipeline(
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds

View File

@@ -223,7 +223,7 @@ class StableDiffusion3ControlNetInpaintingPipeline(
return torch.zeros(
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -312,7 +312,7 @@ class StableDiffusion3ControlNetInpaintingPipeline(
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds

View File

@@ -173,8 +173,10 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
)
self.prompt_template_encode_start_idx = 129
self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio
self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
)
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@staticmethod
@@ -384,6 +386,9 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
if not isinstance(prompt, list):
prompt = [prompt]
batch_size = len(prompt)
prompt = [prompt_clean(p) for p in prompt]

View File

@@ -176,7 +176,7 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin
return torch.zeros(
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -265,7 +265,7 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds

View File

@@ -178,7 +178,7 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
return torch.zeros(
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -267,7 +267,7 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds

View File

@@ -173,7 +173,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
return torch.zeros(
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -261,7 +261,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds

View File

@@ -183,7 +183,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
return torch.zeros(
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -272,7 +272,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds

View File

@@ -189,7 +189,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
return torch.zeros(
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -278,7 +278,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds

View File

@@ -138,34 +138,36 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
transformer ([`WanVACETransformer3DModel`]):
Conditional Transformer to denoise the input latents.
transformer_2 ([`WanVACETransformer3DModel`], *optional*):
Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
`transformer` is used.
scheduler ([`UniPCMultistepScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
scheduler ([`UniPCMultistepScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
transformer ([`WanVACETransformer3DModel`], *optional*):
Conditional Transformer to denoise the input latents during the high-noise stage. In two-stage denoising,
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
`transformer` or `transformer_2` must be provided.
transformer_2 ([`WanVACETransformer3DModel`], *optional*):
Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
`transformer` or `transformer_2` must be provided.
boundary_ratio (`float`, *optional*, defaults to `None`):
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
boundary_timestep. If `None`, only the available transformer is used for the entire denoising process.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_optional_components = ["transformer_2"]
_optional_components = ["transformer", "transformer_2"]
def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
transformer: WanVACETransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
transformer: WanVACETransformer3DModel = None,
transformer_2: WanVACETransformer3DModel = None,
boundary_ratio: Optional[float] = None,
):
@@ -322,7 +324,15 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
reference_images=None,
guidance_scale_2=None,
):
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
if self.transformer is not None:
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
elif self.transformer_2 is not None:
base = self.vae_scale_factor_spatial * self.transformer_2.config.patch_size[1]
else:
raise ValueError(
"`transformer` or `transformer_2` component must be set in order to run inference with this pipeline"
)
if height % base != 0 or width % base != 0:
raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.")
@@ -400,7 +410,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
device: Optional[torch.device] = None,
):
if video is not None:
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
base = self.vae_scale_factor_spatial * (
self.transformer.config.patch_size[1]
if self.transformer is not None
else self.transformer_2.config.patch_size[1]
)
video_height, video_width = self.video_processor.get_default_height_width(video[0])
if video_height * video_width > height * width:
@@ -575,7 +589,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
"Generating with more than one video is not yet supported. This may be supported in the future."
)
transformer_patch_size = self.transformer.config.patch_size[1]
transformer_patch_size = (
self.transformer.config.patch_size[1]
if self.transformer is not None
else self.transformer_2.config.patch_size[1]
)
mask_list = []
for mask_, reference_images_batch in zip(mask, reference_images):
@@ -830,20 +848,25 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
batch_size = prompt_embeds.shape[0]
vae_dtype = self.vae.dtype
transformer_dtype = self.transformer.dtype
transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
vace_layers = (
self.transformer.config.vace_layers
if self.transformer is not None
else self.transformer_2.config.vace_layers
)
if isinstance(conditioning_scale, (int, float)):
conditioning_scale = [conditioning_scale] * len(self.transformer.config.vace_layers)
conditioning_scale = [conditioning_scale] * len(vace_layers)
if isinstance(conditioning_scale, list):
if len(conditioning_scale) != len(self.transformer.config.vace_layers):
if len(conditioning_scale) != len(vace_layers):
raise ValueError(
f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(self.transformer.config.vace_layers)}."
f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(vace_layers)}."
)
conditioning_scale = torch.tensor(conditioning_scale)
if isinstance(conditioning_scale, torch.Tensor):
if conditioning_scale.size(0) != len(self.transformer.config.vace_layers):
if conditioning_scale.size(0) != len(vace_layers):
raise ValueError(
f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(self.transformer.config.vace_layers)}."
f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(vace_layers)}."
)
conditioning_scale = conditioning_scale.to(device=device, dtype=transformer_dtype)
@@ -886,7 +909,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
conditioning_latents = conditioning_latents.to(transformer_dtype)
num_channels_latents = self.transformer.config.in_channels
num_channels_latents = (
self.transformer.config.in_channels
if self.transformer is not None
else self.transformer_2.config.in_channels
)
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
@@ -954,7 +981,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

View File

@@ -64,6 +64,8 @@ from .import_utils import (
get_objects_from_module,
is_accelerate_available,
is_accelerate_version,
is_aiter_available,
is_aiter_version,
is_better_profanity_available,
is_bitsandbytes_available,
is_bitsandbytes_version,

View File

@@ -588,6 +588,21 @@ class AutoModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class BriaFiboTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class BriaTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -482,6 +482,21 @@ class AuraFlowPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class BriaFiboPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class BriaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -254,6 +254,7 @@ def get_cached_module_file(
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
local_dir: Optional[str] = None,
):
"""
Prepares and downloads a module from a local folder or a distant repo and returns its path inside the cached
@@ -332,6 +333,7 @@ def get_cached_module_file(
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
local_dir=local_dir,
)
submodule = "git"
module_file = pretrained_model_name_or_path + ".py"
@@ -355,6 +357,7 @@ def get_cached_module_file(
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
local_dir=local_dir,
token=token,
)
submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
@@ -415,6 +418,7 @@ def get_cached_module_file(
token=token,
revision=revision,
local_files_only=local_files_only,
local_dir=local_dir,
)
return os.path.join(full_submodule, module_file)
@@ -431,7 +435,7 @@ def get_class_from_dynamic_module(
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
**kwargs,
local_dir: Optional[str] = None,
):
"""
Extracts a class from a module file, present in the local folder or repository of a model.
@@ -496,5 +500,6 @@ def get_class_from_dynamic_module(
token=token,
revision=revision,
local_files_only=local_files_only,
local_dir=local_dir,
)
return get_class_in_module(class_name, final_module)

View File

@@ -226,6 +226,7 @@ _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("
_sageattention_available, _sageattention_version = _is_package_available("sageattention")
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
_aiter_available, _aiter_version = _is_package_available("aiter")
_kornia_available, _kornia_version = _is_package_available("kornia")
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
@@ -406,6 +407,10 @@ def is_flash_attn_3_available():
return _flash_attn_3_available
def is_aiter_available():
return _aiter_available
def is_kornia_available():
return _kornia_available
@@ -911,6 +916,22 @@ def is_flash_attn_version(operation: str, version: str):
return compare_versions(parse(_flash_attn_version), operation, version)
@cache
def is_aiter_version(operation: str, version: str):
"""
Compares the current aiter version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _aiter_available:
return False
return compare_versions(parse(_aiter_version), operation, version)
def get_objects_from_module(module):
"""
Returns a dict of object names and values in a module, while skipping private/internal objects

View File

@@ -20,7 +20,7 @@ from transformers import Gemma2Model, GemmaTokenizer
from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
from ..testing_utils import floats_tensor, require_peft_backend
from ..testing_utils import IS_GITHUB_ACTIONS, floats_tensor, require_peft_backend
sys.path.append(".")
@@ -136,3 +136,7 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
def test_layerwise_casting_inference_denoiser(self):
return super().test_layerwise_casting_inference_denoiser()

View File

@@ -17,11 +17,7 @@ import unittest
from diffusers import AutoencoderDC
from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
@@ -82,3 +78,7 @@ class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.Test
init_dict = self.get_autoencoder_dc_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
def test_layerwise_casting_inference(self):
super().test_layerwise_casting_inference()

View File

@@ -0,0 +1,89 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import BriaFiboTransformer2DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class BriaFiboTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = BriaFiboTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.8, 0.7, 0.7]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 1
num_latent_channels = 48
num_image_channels = 3
height = width = 16
sequence_length = 32
embedding_dim = 64
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"text_encoder_layers": [encoder_hidden_states[:, :, :32], encoder_hidden_states[:, :, :32]],
}
@property
def input_shape(self):
return (16, 16)
@property
def output_shape(self):
return (256, 48)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 1,
"in_channels": 48,
"num_layers": 1,
"num_single_layers": 1,
"attention_head_dim": 8,
"num_attention_heads": 2,
"joint_attention_dim": 64,
"text_encoder_dim": 32,
"pooled_projection_dim": None,
"axes_dims_rope": [0, 4, 4],
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"BriaFiboTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

View File

@@ -14,6 +14,10 @@ pytest tests/others/test_attention_backends.py
Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
Tests for the aiter backend were conducted, and its reference slices collected, on an MI355X
with a torch nightly build from 2025-09-25 (ad2f7315ca66b42497047bb7951f696b50f1e81b) and
aiter 0.1.5.post4.dev20+ga25e55e79.
"""
import os
@@ -44,6 +48,10 @@ FORWARD_CASES = [
"_native_cudnn",
torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
),
(
"aiter",
torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16),
)
]
COMPILE_CASES = [
@@ -63,6 +71,11 @@ COMPILE_CASES = [
torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
True,
),
(
"aiter",
torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16),
True,
)
]
# fmt: on

View File

View File

@@ -0,0 +1,139 @@
# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer
from transformers.models.smollm3.modeling_smollm3 import SmolLM3Config, SmolLM3ForCausalLM
from diffusers import (
AutoencoderKLWan,
BriaFiboPipeline,
FlowMatchEulerDiscreteScheduler,
)
from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
from tests.pipelines.test_pipelines_common import PipelineTesterMixin
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
enable_full_determinism()
class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = BriaFiboPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale"])
batch_params = frozenset(["prompt"])
test_xformers_attention = False
test_layerwise_casting = False
test_group_offloading = False
supports_dduf = False
def get_dummy_components(self):
torch.manual_seed(0)
transformer = BriaFiboTransformer2DModel(
patch_size=1,
in_channels=16,
num_layers=1,
num_single_layers=1,
attention_head_dim=8,
num_attention_heads=2,
joint_attention_dim=64,
text_encoder_dim=32,
pooled_projection_dim=None,
axes_dims_rope=[0, 4, 4],
)
torch.manual_seed(0)
vae = AutoencoderKLWan(
base_dim=160,
decoder_base_dim=256,
num_res_blocks=2,
out_channels=12,
patch_size=2,
scale_factor_spatial=16,
scale_factor_temporal=4,
temperal_downsample=[False, True, True],
z_dim=16,
)
scheduler = FlowMatchEulerDiscreteScheduler()
torch.manual_seed(0)
text_encoder = SmolLM3ForCausalLM(SmolLM3Config(hidden_size=32))
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"transformer": transformer,
"vae": vae,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": "{'text': 'A painting of a squirrel eating a burger'}",
"negative_prompt": "bad, ugly",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"height": 32,
"width": 32,
"output_type": "np",
}
return inputs
@unittest.skip(reason="will not be supported due to dim-fusion")
def test_encode_prompt_works_in_isolation(self):
pass
def test_bria_fibo_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components())
pipe = pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt"] = "a different prompt"
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
assert max_diff > 1e-6
def test_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components())
pipe = pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (64, 64), (32, 64)]
for height, width in height_width_pairs:
expected_height = height
expected_width = width
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)
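# Hedged sketch of the structured-prompt convention exercised above, but with a real
# checkpoint; the repo id "briaai/FIBO" comes from the model docs, while the JSON keys
# and generation settings below are illustrative assumptions, not a documented schema.
def _bria_fibo_structured_prompt_sketch():
    import json

    pipe = BriaFiboPipeline.from_pretrained("briaai/FIBO", torch_dtype=torch.bfloat16).to("cuda")
    # FIBO expects a structured (JSON) prompt rather than freeform text.
    prompt = json.dumps({"text": "A painting of a squirrel eating a burger"})
    return pipe(prompt=prompt, num_inference_steps=30, guidance_scale=5.0, output_type="pil").images[0]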

View File

View File

@@ -0,0 +1,306 @@
# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from transformers import (
CLIPTextConfig,
CLIPTextModel,
CLIPTokenizer,
Qwen2_5_VLConfig,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLProcessor,
)
from diffusers import (
AutoencoderKLHunyuanVideo,
FlowMatchEulerDiscreteScheduler,
Kandinsky5T2VPipeline,
Kandinsky5Transformer3DModel,
)
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class Kandinsky5T2VPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = Kandinsky5T2VPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "prompt_embeds", "negative_prompt_embeds"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
# Required optional call parameters for this pipeline
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
"max_sequence_length",
]
)
test_xformers_attention = False
supports_dduf = False
def get_dummy_components(self):
torch.manual_seed(0)
vae = AutoencoderKLHunyuanVideo(
in_channels=3,
out_channels=3,
spatial_compression_ratio=8,
temporal_compression_ratio=4,
latent_channels=4,
block_out_channels=(8, 8, 8, 8),
layers_per_block=1,
norm_num_groups=4,
)
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
# Dummy Qwen2.5-VL model
config = Qwen2_5_VLConfig(
text_config={
"hidden_size": 16,
"intermediate_size": 16,
"num_hidden_layers": 2,
"num_attention_heads": 2,
"num_key_value_heads": 2,
"rope_scaling": {
"mrope_section": [1, 1, 2],
"rope_type": "default",
"type": "default",
},
"rope_theta": 1000000.0,
},
vision_config={
"depth": 2,
"hidden_size": 16,
"intermediate_size": 16,
"num_heads": 2,
"out_hidden_size": 16,
},
hidden_size=16,
vocab_size=152064,
vision_end_token_id=151653,
vision_start_token_id=151652,
vision_token_id=151654,
)
text_encoder = Qwen2_5_VLForConditionalGeneration(config)
tokenizer = Qwen2VLProcessor.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
# Dummy CLIP model
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
torch.manual_seed(0)
text_encoder_2 = CLIPTextModel(clip_text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
torch.manual_seed(0)
transformer = Kandinsky5Transformer3DModel(
in_visual_dim=4,
in_text_dim=16, # Match tiny Qwen2.5-VL hidden size
in_text_dim2=32, # Match tiny CLIP hidden size
time_dim=32,
out_visual_dim=4,
patch_size=(1, 2, 2),
model_dim=48,
ff_dim=128,
num_text_blocks=1,
num_visual_blocks=1,
axes_dims=(8, 8, 8),
visual_cond=False,
)
components = {
"transformer": transformer.eval(),
"vae": vae.eval(),
"scheduler": scheduler,
"text_encoder": text_encoder.eval(),
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2.eval(),
"tokenizer_2": tokenizer_2,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A cat dancing",
"negative_prompt": "blurry, low quality",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"height": 32,
"width": 32,
"num_frames": 5,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
# Check video shape: (batch, frames, channel, height, width)
expected_shape = (1, 5, 3, 32, 32)
self.assertEqual(video.shape, expected_shape)
# Check specific values
expected_slice = torch.tensor(
[
0.4330,
0.4254,
0.4285,
0.3835,
0.4253,
0.4196,
0.3704,
0.3714,
0.4999,
0.5346,
0.4795,
0.4637,
0.4930,
0.5124,
0.4902,
0.4570,
]
)
generated_slice = video.flatten()
# Take first 8 and last 8 values for comparison
video_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
self.assertTrue(
torch.allclose(video_slice, expected_slice, atol=1e-3),
f"video_slice: {video_slice}, expected_slice: {expected_slice}",
)
def test_inference_batch_single_identical(self):
# Override to test batch single identical with video
super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=1e-2)
def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-3, rtol=1e-3):
components = self.get_dummy_components()
text_component_names = ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]
text_components = {k: (v if k in text_component_names else None) for k, v in components.items()}
non_text_components = {k: (v if k not in text_component_names else None) for k, v in components.items()}
pipe_with_just_text_encoder = self.pipeline_class(**text_components)
pipe_with_just_text_encoder = pipe_with_just_text_encoder.to(torch_device)
pipe_without_text_encoders = self.pipeline_class(**non_text_components)
pipe_without_text_encoders = pipe_without_text_encoders.to(torch_device)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
# Compute `encode_prompt()`.
# Test single prompt
prompt = "A cat dancing"
with torch.no_grad():
prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = pipe_with_just_text_encoder.encode_prompt(
prompt, device=torch_device, max_sequence_length=16
)
# Check shapes
self.assertEqual(prompt_embeds_qwen.shape, (1, 4, 16)) # [batch, seq_len, embed_dim]
self.assertEqual(prompt_embeds_clip.shape, (1, 32)) # [batch, embed_dim]
self.assertEqual(prompt_cu_seqlens.shape, (2,)) # [batch + 1]
# Test batch of prompts
prompts = ["A cat dancing", "A dog running"]
with torch.no_grad():
batch_embeds_qwen, batch_embeds_clip, batch_cu_seqlens = pipe_with_just_text_encoder.encode_prompt(
prompts, device=torch_device, max_sequence_length=16
)
# Check batch size
self.assertEqual(batch_embeds_qwen.shape, (len(prompts), 4, 16))
self.assertEqual(batch_embeds_clip.shape, (len(prompts), 32))
self.assertEqual(len(batch_cu_seqlens), len(prompts) + 1) # [0, len1, len1+len2]
inputs = self.get_dummy_inputs(torch_device)
inputs["guidance_scale"] = 1.0
# baseline output: full pipeline
pipe_out = pipe(**inputs).frames
# test against pipeline call with pre-computed prompt embeds
inputs = self.get_dummy_inputs(torch_device)
inputs["guidance_scale"] = 1.0
with torch.no_grad():
prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = pipe_with_just_text_encoder.encode_prompt(
inputs["prompt"], device=torch_device, max_sequence_length=inputs["max_sequence_length"]
)
inputs["prompt"] = None
inputs["prompt_embeds_qwen"] = prompt_embeds_qwen
inputs["prompt_embeds_clip"] = prompt_embeds_clip
inputs["prompt_cu_seqlens"] = prompt_cu_seqlens
pipe_out_2 = pipe_without_text_encoders(**inputs)[0]
self.assertTrue(
torch.allclose(pipe_out, pipe_out_2, atol=atol, rtol=rtol),
f"max diff: {torch.max(torch.abs(pipe_out - pipe_out_2))}",
)
@unittest.skip("Kandinsky5T2VPipeline does not support attention slicing")
def test_attention_slicing_forward_pass(self):
pass
@unittest.skip("Kandinsky5T2VPipeline does not support xformers")
def test_xformers_attention_forwardGenerator_pass(self):
pass
@unittest.skip("Kandinsky5T2VPipeline does not support VAE slicing")
def test_vae_slicing(self):
pass
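# Hedged sketch of the precomputed-embedding path exercised in
# `test_encode_prompt_works_in_isolation` above; the keyword names mirror that test,
# while the two pipeline objects are assumed to be built from
# `get_dummy_components`-style components.
def _kandinsky5_precomputed_prompt_sketch(pipe_with_text_encoders, pipe_without_text_encoders, device):
    with torch.no_grad():
        embeds_qwen, embeds_clip, cu_seqlens = pipe_with_text_encoders.encode_prompt(
            "A cat dancing", device=device, max_sequence_length=16
        )
    # guidance_scale=1.0 disables classifier-free guidance, so no negative embeddings are needed.
    return pipe_without_text_encoders(
        prompt=None,
        prompt_embeds_qwen=embeds_qwen,
        prompt_embeds_clip=embeds_clip,
        prompt_cu_seqlens=cu_seqlens,
        num_inference_steps=2,
        guidance_scale=1.0,
        height=32,
        width=32,
        num_frames=5,
        output_type="pt",
    ).frames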

View File

@@ -23,6 +23,7 @@ from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
from ...testing_utils import (
IS_GITHUB_ACTIONS,
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
@@ -304,6 +305,10 @@ class SanaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
# Requires higher tolerance as model seems very sensitive to dtype
super().test_float16_inference(expected_max_diff=0.08)
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
def test_layerwise_casting_inference(self):
super().test_layerwise_casting_inference()
@slow
@require_torch_accelerator

View File

@@ -28,10 +28,7 @@ from diffusers import (
)
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -326,3 +323,7 @@ class SanaControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def test_float16_inference(self):
# Requires higher tolerance as model seems very sensitive to dtype
super().test_float16_inference(expected_max_diff=0.08)
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
def test_layerwise_casting_inference(self):
super().test_layerwise_casting_inference()

View File

@@ -21,10 +21,7 @@ from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
from diffusers import AutoencoderDC, SanaSprintPipeline, SanaTransformer2DModel, SCMScheduler
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -300,3 +297,7 @@ class SanaSprintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def test_float16_inference(self):
# Requires higher tolerance as model seems very sensitive to dtype
super().test_float16_inference(expected_max_diff=0.08)
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
def test_layerwise_casting_inference(self):
super().test_layerwise_casting_inference()

View File

@@ -22,10 +22,7 @@ from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
from diffusers import AutoencoderDC, SanaSprintImg2ImgPipeline, SanaTransformer2DModel, SCMScheduler
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, torch_device
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
@@ -312,3 +309,7 @@ class SanaSprintImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
def test_float16_inference(self):
# Requires higher tolerance as model seems very sensitive to dtype
super().test_float16_inference(expected_max_diff=0.08)
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
def test_layerwise_casting_inference(self):
super().test_layerwise_casting_inference()

View File

@@ -1461,6 +1461,8 @@ class PipelineTesterMixin:
def test_save_load_optional_components(self, expected_max_difference=1e-4):
if not hasattr(self.pipeline_class, "_optional_components"):
return
if not self.pipeline_class._optional_components:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import numpy as np
@@ -19,9 +20,15 @@ import torch
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
from diffusers import (
AutoencoderKLWan,
FlowMatchEulerDiscreteScheduler,
UniPCMultistepScheduler,
WanVACEPipeline,
WanVACETransformer3DModel,
)
from ...testing_utils import enable_full_determinism
from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@@ -212,3 +219,81 @@ class WanVACEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
)
def test_save_load_float16(self):
pass
def test_inference_with_only_transformer(self):
components = self.get_dummy_components()
components["transformer_2"] = None
components["boundary_ratio"] = 0.0
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
video = pipe(**inputs).frames[0]
assert video.shape == (17, 3, 16, 16)
def test_inference_with_only_transformer_2(self):
components = self.get_dummy_components()
components["transformer_2"] = components["transformer"]
components["transformer"] = None
# FlowMatchEulerDiscreteScheduler doesn't support running only the low-noise stage
# because the starting timestep t == 1000 equals boundary_timestep
components["scheduler"] = UniPCMultistepScheduler(
prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
)
components["boundary_ratio"] = 1.0
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
video = pipe(**inputs).frames[0]
assert video.shape == (17, 3, 16, 16)
def test_save_load_optional_components(self, expected_max_difference=1e-4):
optional_component = ["transformer"]
components = self.get_dummy_components()
components["transformer_2"] = components["transformer"]
# FlowMatchEulerDiscreteScheduler doesn't support running only the low-noise stage
# because the starting timestep t == 1000 equals boundary_timestep
components["scheduler"] = UniPCMultistepScheduler(
prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
)
for component in optional_component:
components[component] = None
components["boundary_ratio"] = 1.0
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
torch.manual_seed(0)
output = pipe(**inputs)[0]
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, safe_serialization=False)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
for component in pipe_loaded.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
for component in optional_component:
assert getattr(pipe_loaded, component) is None, f"`{component}` did not stay set to None after loading."
inputs = self.get_dummy_inputs(generator_device)
torch.manual_seed(0)
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
assert max_diff < expected_max_difference, "Outputs exceed expected maximum difference"

View File

@@ -63,6 +63,8 @@ else:
IS_CUDA_SYSTEM = False
IS_XPU_SYSTEM = False
IS_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" and os.getenv("DIFFUSERS_IS_CI") == "yes"
global_rng = random.Random()
logger = get_logger(__name__)