From 500b9cf184d9a6aa950ccbb9a4d967b95877196d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sun, 26 Oct 2025 23:11:23 +0530 Subject: [PATCH 01/16] [chore] Move guiders experimental warning (#12543) * move guiders experimental warning to init. * up --- src/diffusers/guiders/__init__.py | 6 ------ src/diffusers/guiders/guider_utils.py | 4 ++++ 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index 6c95b8e0c4..4e53c373c4 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -17,12 +17,6 @@ from typing import Union from ..utils import is_torch_available, logging -logger = logging.get_logger(__name__) -logger.warning( - "Guiders are currently an experimental feature under active development. The API is subject to breaking changes in future releases." -) - - if is_torch_available(): from .adaptive_projected_guidance import AdaptiveProjectedGuidance from .adaptive_projected_guidance_mix import AdaptiveProjectedMixGuidance diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index b7afd8648b..71e4becfcd 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -41,6 +41,10 @@ class BaseGuidance(ConfigMixin, PushToHubMixin): _identifier_key = "__guidance_identifier__" def __init__(self, start: float = 0.0, stop: float = 1.0, enabled: bool = True): + logger.warning( + "Guiders are currently an experimental feature under active development. The API is subject to breaking changes in future releases." + ) + self._start = start self._stop = stop self._step: int = None From dc6bd1511a4948ebca35b22609002bba58e71c83 Mon Sep 17 00:00:00 2001 From: josephrocca <1167575+josephrocca@users.noreply.github.com> Date: Mon, 27 Oct 2025 18:55:20 +0800 Subject: [PATCH 02/16] Fix Chroma attention padding order and update docs to use `lodestones/Chroma1-HD` (#12508) * [Fix] Move attention mask padding after T5 embedding * [Fix] Move attention mask padding after T5 embedding * Clean up whitespace in pipeline_chroma.py Removed unnecessary blank lines for cleaner code. * Fix * Fix * Update model to final Chroma1-HD checkpoint * Update to Chroma1-HD * Update model to Chroma1-HD * Update model to Chroma1-HD * Update Chroma model links to Chroma1-HD * Add comment about padding/masking * Fix checkpoint/repo references * Apply style fixes --------- Co-authored-by: github-actions[bot] Co-authored-by: Dhruv Nair --- .../en/api/models/chroma_transformer.md | 2 +- docs/source/en/api/pipelines/chroma.md | 13 +++++----- .../models/transformers/transformer_chroma.py | 2 +- .../pipelines/chroma/pipeline_chroma.py | 25 +++++++++++-------- .../chroma/pipeline_chroma_img2img.py | 23 +++++++++-------- 5 files changed, 35 insertions(+), 30 deletions(-) diff --git a/docs/source/en/api/models/chroma_transformer.md b/docs/source/en/api/models/chroma_transformer.md index 681e81f7a5..1ef24cda39 100644 --- a/docs/source/en/api/models/chroma_transformer.md +++ b/docs/source/en/api/models/chroma_transformer.md @@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License. 
# ChromaTransformer2DModel -A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma) +A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma1-HD) ## ChromaTransformer2DModel diff --git a/docs/source/en/api/pipelines/chroma.md b/docs/source/en/api/pipelines/chroma.md index df03fbb325..cc52ffa09a 100644 --- a/docs/source/en/api/pipelines/chroma.md +++ b/docs/source/en/api/pipelines/chroma.md @@ -19,20 +19,21 @@ specific language governing permissions and limitations under the License. Chroma is a text to image generation model based on Flux. -Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma). +Original model checkpoints for Chroma can be found here: +* High-resolution finetune: [lodestones/Chroma1-HD](https://huggingface.co/lodestones/Chroma1-HD) +* Base model: [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base) +* Original repo with progress checkpoints: [lodestones/Chroma](https://huggingface.co/lodestones/Chroma) (loading this repo with `from_pretrained` will load a Diffusers-compatible version of the `unlocked-v37` checkpoint) > [!TIP] > Chroma can use all the same optimizations as Flux. ## Inference -The Diffusers version of Chroma is based on the [`unlocked-v37`](https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors) version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma). - ```python import torch from diffusers import ChromaPipeline -pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16) +pipe = ChromaPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16) pipe.enable_model_cpu_offload() prompt = [ @@ -63,10 +64,10 @@ Then run the following example import torch from diffusers import ChromaTransformer2DModel, ChromaPipeline -model_id = "lodestones/Chroma" +model_id = "lodestones/Chroma1-HD" dtype = torch.bfloat16 -transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors", torch_dtype=dtype) +transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors", torch_dtype=dtype) pipe = ChromaPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype) pipe.enable_model_cpu_offload() diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index 5823ae9d3d..2ef3643daf 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -379,7 +379,7 @@ class ChromaTransformer2DModel( """ The Transformer model introduced in Flux, modified for Chroma. 
- Reference: https://huggingface.co/lodestones/Chroma + Reference: https://huggingface.co/lodestones/Chroma1-HD Args: patch_size (`int`, defaults to `1`): diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py index 5482035b3a..ed6c2c2105 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py @@ -53,8 +53,8 @@ EXAMPLE_DOC_STRING = """ >>> import torch >>> from diffusers import ChromaPipeline - >>> model_id = "lodestones/Chroma" - >>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors" + >>> model_id = "lodestones/Chroma1-HD" + >>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors" >>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16) >>> pipe = ChromaPipeline.from_pretrained( ... model_id, @@ -158,7 +158,7 @@ class ChromaPipeline( r""" The Chroma pipeline for text-to-image generation. - Reference: https://huggingface.co/lodestones/Chroma/ + Reference: https://huggingface.co/lodestones/Chroma1-HD/ Args: transformer ([`ChromaTransformer2DModel`]): @@ -233,20 +233,23 @@ class ChromaPipeline( return_tensors="pt", ) text_input_ids = text_inputs.input_ids - attention_mask = text_inputs.attention_mask.clone() + tokenizer_mask = text_inputs.attention_mask - # Chroma requires the attention mask to include one padding token - seq_lengths = attention_mask.sum(dim=1) - mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1) - attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool() + tokenizer_mask_device = tokenizer_mask.to(device) + # unlike FLUX, Chroma uses the attention mask when generating the T5 embedding prompt_embeds = self.text_encoder( - text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device) + text_input_ids.to(device), + output_hidden_states=False, + attention_mask=tokenizer_mask_device, )[0] - dtype = self.text_encoder.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - attention_mask = attention_mask.to(device=device) + + # for the text tokens, chroma requires that all except the first padding token are masked out during the forward pass through the transformer + seq_lengths = tokenizer_mask_device.sum(dim=1) + mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1) + attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device) _, seq_len, _ = prompt_embeds.shape diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py index 9afd4b9e15..470c746e41 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py @@ -53,8 +53,8 @@ EXAMPLE_DOC_STRING = """ >>> import torch >>> from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline - >>> model_id = "lodestones/Chroma" - >>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors" + >>> model_id = "lodestones/Chroma1-HD" + >>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors" >>> pipe = ChromaImg2ImgPipeline.from_pretrained( ... model_id, ... transformer=transformer, @@ -170,7 +170,7 @@ class ChromaImg2ImgPipeline( r""" The Chroma pipeline for image-to-image generation. 
- Reference: https://huggingface.co/lodestones/Chroma/ + Reference: https://huggingface.co/lodestones/Chroma1-HD/ Args: transformer ([`ChromaTransformer2DModel`]): @@ -247,20 +247,21 @@ class ChromaImg2ImgPipeline( return_tensors="pt", ) text_input_ids = text_inputs.input_ids - attention_mask = text_inputs.attention_mask.clone() + tokenizer_mask = text_inputs.attention_mask - # Chroma requires the attention mask to include one padding token - seq_lengths = attention_mask.sum(dim=1) - mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1) - attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long() + tokenizer_mask_device = tokenizer_mask.to(device) prompt_embeds = self.text_encoder( - text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device) + text_input_ids.to(device), + output_hidden_states=False, + attention_mask=tokenizer_mask_device, )[0] - dtype = self.text_encoder.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - attention_mask = attention_mask.to(dtype=dtype, device=device) + + seq_lengths = tokenizer_mask_device.sum(dim=1) + mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1) + attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device) _, seq_len, _ = prompt_embeds.shape From 250f5cb53db1554f32dee07ad002f6c3834306d0 Mon Sep 17 00:00:00 2001 From: Mikko Lauri <167742591+lauri9@users.noreply.github.com> Date: Mon, 27 Oct 2025 16:55:02 +0200 Subject: [PATCH 03/16] Add AITER attention backend (#12549) * add aiter attention backend * Apply style fixes --------- Co-authored-by: Sayak Paul Co-authored-by: github-actions[bot] --- .../en/optimization/attention_backends.md | 2 + src/diffusers/models/attention_dispatch.py | 60 +++++++++++++++++++ src/diffusers/utils/__init__.py | 2 + src/diffusers/utils/import_utils.py | 21 +++++++ tests/others/test_attention_backends.py | 13 ++++ 5 files changed, 98 insertions(+) diff --git a/docs/source/en/optimization/attention_backends.md b/docs/source/en/optimization/attention_backends.md index 8be2c06030..edfdcc38b5 100644 --- a/docs/source/en/optimization/attention_backends.md +++ b/docs/source/en/optimization/attention_backends.md @@ -21,6 +21,7 @@ Refer to the table below for an overview of the available attention families and | attention family | main feature | |---|---| | FlashAttention | minimizes memory reads/writes through tiling and recomputation | +| AI Tensor Engine for ROCm | FlashAttention implementation optimized for AMD ROCm accelerators | | SageAttention | quantizes attention to int8 | | PyTorch native | built-in PyTorch implementation using [scaled_dot_product_attention](./fp16#scaled-dot-product-attention) | | xFormers | memory-efficient attention with support for various attention kernels | @@ -139,6 +140,7 @@ Refer to the table below for a complete list of available attention backends and | `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention | | `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 | | `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention | +| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm | | `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) 
| FlashAttention-3 | | `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 | | `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels | diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index e169491099..ab0d7102ee 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -27,6 +27,8 @@ if torch.distributed.is_available(): from ..utils import ( get_logger, + is_aiter_available, + is_aiter_version, is_flash_attn_3_available, is_flash_attn_available, is_flash_attn_version, @@ -47,6 +49,7 @@ if TYPE_CHECKING: from ._modeling_parallel import ParallelConfig _REQUIRED_FLASH_VERSION = "2.6.3" +_REQUIRED_AITER_VERSION = "0.1.5" _REQUIRED_SAGE_VERSION = "2.1.1" _REQUIRED_FLEX_VERSION = "2.5.0" _REQUIRED_XLA_VERSION = "2.2" @@ -54,6 +57,7 @@ _REQUIRED_XFORMERS_VERSION = "0.0.29" _CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION) _CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available() +_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION) _CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION) _CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION) _CAN_USE_NPU_ATTN = is_torch_npu_available() @@ -78,6 +82,12 @@ else: flash_attn_3_func = None flash_attn_3_varlen_func = None + +if _CAN_USE_AITER_ATTN: + from aiter import flash_attn_func as aiter_flash_attn_func +else: + aiter_flash_attn_func = None + if DIFFUSERS_ENABLE_HUB_KERNELS: if not is_kernels_available(): raise ImportError( @@ -178,6 +188,9 @@ class AttentionBackendName(str, Enum): _FLASH_3_HUB = "_flash_3_hub" # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet. + # `aiter` + AITER = "aiter" + # PyTorch native FLEX = "flex" NATIVE = "native" @@ -414,6 +427,12 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`." ) + elif backend == AttentionBackendName.AITER: + if not _CAN_USE_AITER_ATTN: + raise RuntimeError( + f"Aiter Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`." + ) + elif backend in [ AttentionBackendName.SAGE, AttentionBackendName.SAGE_VARLEN, @@ -1397,6 +1416,47 @@ def _flash_varlen_attention_3( return (out, lse) if return_lse else out +@_AttentionBackendRegistry.register( + AttentionBackendName.AITER, + constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _aiter_flash_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, +) -> torch.Tensor: + if not return_lse and torch.is_grad_enabled(): + # aiter requires return_lse=True by assertion when gradients are enabled. 
+ out, lse, *_ = aiter_flash_attn_func( + q=query, + k=key, + v=value, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + return_lse=True, + ) + else: + out = aiter_flash_attn_func( + q=query, + k=key, + v=value, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + return_lse=return_lse, + ) + if return_lse: + out, lse, *_ = out + + return (out, lse) if return_lse else out + + @_AttentionBackendRegistry.register( AttentionBackendName.FLEX, constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index d8e1a55401..cf77aaee82 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -64,6 +64,8 @@ from .import_utils import ( get_objects_from_module, is_accelerate_available, is_accelerate_version, + is_aiter_available, + is_aiter_version, is_better_profanity_available, is_bitsandbytes_available, is_bitsandbytes_version, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 97065267b0..adf8ed8b06 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -226,6 +226,7 @@ _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available(" _sageattention_available, _sageattention_version = _is_package_available("sageattention") _flash_attn_available, _flash_attn_version = _is_package_available("flash_attn") _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3") +_aiter_available, _aiter_version = _is_package_available("aiter") _kornia_available, _kornia_version = _is_package_available("kornia") _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True) @@ -406,6 +407,10 @@ def is_flash_attn_3_available(): return _flash_attn_3_available +def is_aiter_available(): + return _aiter_available + + def is_kornia_available(): return _kornia_available @@ -911,6 +916,22 @@ def is_flash_attn_version(operation: str, version: str): return compare_versions(parse(_flash_attn_version), operation, version) +@cache +def is_aiter_version(operation: str, version: str): + """ + Compares the current aiter version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _aiter_available: + return False + return compare_versions(parse(_aiter_version), operation, version) + + def get_objects_from_module(module): """ Returns a dict of object names and values in a module, while skipping private/internal objects diff --git a/tests/others/test_attention_backends.py b/tests/others/test_attention_backends.py index 42cdcd56f7..2e5a2fc82b 100644 --- a/tests/others/test_attention_backends.py +++ b/tests/others/test_attention_backends.py @@ -14,6 +14,10 @@ pytest tests/others/test_attention_backends.py Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in "native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128). + +Tests for aiter backend were conducted and slices for the aiter backend tests collected on a MI355X +with torch 2025-09-25 nightly version (ad2f7315ca66b42497047bb7951f696b50f1e81b) and +aiter 0.1.5.post4.dev20+ga25e55e79. 
""" import os @@ -44,6 +48,10 @@ FORWARD_CASES = [ "_native_cudnn", torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16), ), + ( + "aiter", + torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16), + ) ] COMPILE_CASES = [ @@ -63,6 +71,11 @@ COMPILE_CASES = [ torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16), True, ), + ( + "aiter", + torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16), + True, + ) ] # fmt: on From 6d1a6486024192951ce696e8f4cf79a39509182f Mon Sep 17 00:00:00 2001 From: alirezafarashah <63104907+alirezafarashah@users.noreply.github.com> Date: Mon, 27 Oct 2025 13:16:43 -0400 Subject: [PATCH 04/16] Fix small inconsistency in output dimension of "_get_t5_prompt_embeds" function in sd3 pipeline (#12531) * Fix small inconsistency in output dimension of t5 embeds when text_encoder_3 is None * first commit --------- Co-authored-by: Alireza Farashah Co-authored-by: Alireza Farashah --- .../controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py | 2 +- .../pipeline_stable_diffusion_3_controlnet_inpainting.py | 2 +- src/diffusers/pipelines/pag/pipeline_pag_sd_3.py | 2 +- src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py | 2 +- .../pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py | 2 +- .../stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py | 2 +- .../stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py index c763411ab5..f67a0e2112 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py @@ -266,7 +266,7 @@ class StableDiffusion3ControlNetPipeline( return torch.zeros( ( batch_size * num_images_per_prompt, - self.tokenizer_max_length, + max_sequence_length, self.transformer.config.joint_attention_dim, ), device=device, diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py index c33cf979c6..68984da4dc 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py @@ -284,7 +284,7 @@ class StableDiffusion3ControlNetInpaintingPipeline( return torch.zeros( ( batch_size * num_images_per_prompt, - self.tokenizer_max_length, + max_sequence_length, self.transformer.config.joint_attention_dim, ), device=device, diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py index acb4e52340..bc281428e2 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py @@ -237,7 +237,7 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin return torch.zeros( ( 
batch_size * num_images_per_prompt, - self.tokenizer_max_length, + max_sequence_length, self.transformer.config.joint_attention_dim, ), device=device, diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py index e1819a79fb..22a8dac238 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py @@ -253,7 +253,7 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, return torch.zeros( ( batch_size * num_images_per_prompt, - self.tokenizer_max_length, + max_sequence_length, self.transformer.config.joint_attention_dim, ), device=device, diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 1618f89a49..3b7b26dc63 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -248,7 +248,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle return torch.zeros( ( batch_size * num_images_per_prompt, - self.tokenizer_max_length, + max_sequence_length, self.transformer.config.joint_attention_dim, ), device=device, diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py index 7e97909f42..db047f1992 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py @@ -272,7 +272,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro return torch.zeros( ( batch_size * num_images_per_prompt, - self.tokenizer_max_length, + max_sequence_length, self.transformer.config.joint_attention_dim, ), device=device, diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py index 5b2cca0378..c95fa530c8 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py @@ -278,7 +278,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro return torch.zeros( ( batch_size * num_images_per_prompt, - self.tokenizer_max_length, + max_sequence_length, self.transformer.config.joint_attention_dim, ), device=device, From 5afbcce176cd4e8ec08f43ee9fae2d6562edf54c Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Tue, 28 Oct 2025 05:17:18 +0300 Subject: [PATCH 05/16] Kandinsky 5 10 sec (NABLA suport) (#12520) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add transformer pipeline first version * updates * fix 5sec generation * rewrite Kandinsky5T2VPipeline to diffusers style * add multiprompt support * remove prints in pipeline * add nabla attention * Wrap Transformer in Diffusers style * fix license * fix prompt type * add gradient checkpointing and peft support * add usage example * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: Álvaro Somoza * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: Álvaro Somoza * Update 
src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: Álvaro Somoza * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: Álvaro Somoza * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: Álvaro Somoza * remove unused imports * add 10 second models support * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * remove no_grad and simplified prompt paddings * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * moved template to __init__ * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * moved sdps inside processor * remove oneline function * remove reset_dtype methods * Transformer: move all methods to forward * separated prompt encoding * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * refactoring * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * refactoring acording to https://github.com/huggingface/diffusers/commit/acabbc0033d4b4933fc651766a4aa026db2e6dc1 * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update 
src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu * fixed * style +copies * Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: Charles * more * Apply suggestions from code review * add lora loader doc * add compiled Nabla Attention * all needed changes for 10 sec models are added! * add docs * Apply style fixes * update docs * add kandinsky5 to toctree * add tests * fix tests * Apply style fixes * update tests --------- Co-authored-by: Álvaro Somoza Co-authored-by: YiYi Xu Co-authored-by: Charles Co-authored-by: Sayak Paul Co-authored-by: github-actions[bot] --- docs/source/en/_toctree.yml | 2 + docs/source/en/api/pipelines/kandinsky5.md | 149 +++++++++ .../transformers/transformer_kandinsky.py | 2 + .../kandinsky5/pipeline_kandinsky.py | 9 +- tests/pipelines/kandinsky5/__init__.py | 0 tests/pipelines/kandinsky5/test_kandinsky5.py | 306 ++++++++++++++++++ tests/pipelines/test_pipelines_common.py | 2 + 7 files changed, 468 insertions(+), 2 deletions(-) create mode 100644 docs/source/en/api/pipelines/kandinsky5.md create mode 100644 tests/pipelines/kandinsky5/__init__.py create mode 100644 tests/pipelines/kandinsky5/test_kandinsky5.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 540e99a2c6..44870f680e 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -525,6 +525,8 @@ title: Kandinsky 2.2 - local: api/pipelines/kandinsky3 title: Kandinsky 3 + - local: api/pipelines/kandinsky5 + title: Kandinsky 5 - local: api/pipelines/kolors title: Kolors - local: api/pipelines/latent_consistency_models diff --git a/docs/source/en/api/pipelines/kandinsky5.md b/docs/source/en/api/pipelines/kandinsky5.md new file mode 100644 index 0000000000..cb1c119f80 --- /dev/null +++ b/docs/source/en/api/pipelines/kandinsky5.md @@ -0,0 +1,149 @@ + + +# Kandinsky 5.0 + +Kandinsky 5.0 is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov + + +Kandinsky 5.0 is a family of diffusion models for Video & Image generation. Kandinsky 5.0 T2V Lite is a lightweight video generation model (2B parameters) that ranks #1 among open-source models in its class. It outperforms larger models and offers the best understanding of Russian concepts in the open-source ecosystem. 
+
+The model introduces several key innovations:
+- **Latent diffusion pipeline** with **Flow Matching** for improved training stability
+- **Diffusion Transformer (DiT)** as the main generative backbone with cross-attention to text embeddings
+- Dual text encoding using **Qwen2.5-VL** and **CLIP** for comprehensive text understanding
+- **HunyuanVideo 3D VAE** for efficient video encoding and decoding
+- **Sparse attention mechanisms** (NABLA) for efficient long-sequence processing
+
+The original codebase can be found at [ai-forever/Kandinsky-5](https://github.com/ai-forever/Kandinsky-5).
+
+> [!TIP]
+> Check out the [AI Forever](https://huggingface.co/ai-forever) organization on the Hub for the official model checkpoints for text-to-video generation, including pretrained, SFT, no-CFG, and distilled variants.
+
+## Available Models
+
+Kandinsky 5.0 T2V Lite comes in several variants optimized for different use cases:
+
+| model_id | Description | Use Cases |
+|------------|-------------|-----------|
+| **ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers** | 5 second Supervised Fine-Tuned model | Highest generation quality |
+| **ai-forever/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers** | 10 second Supervised Fine-Tuned model | Highest generation quality |
+| **ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers** | 5 second Classifier-Free Guidance distilled | 2× faster inference |
+| **ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-10s-Diffusers** | 10 second Classifier-Free Guidance distilled | 2× faster inference |
+| **ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers** | 5 second Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss |
+| **ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-10s-Diffusers** | 10 second Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss |
+| **ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-5s-Diffusers** | 5 second Base pretrained model | Research and fine-tuning |
+| **ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-10s-Diffusers** | 10 second Base pretrained model | Research and fine-tuning |
+
+All models are available in 5-second and 10-second video generation versions.
+
+## Kandinsky5T2VPipeline
+
+[[autodoc]] Kandinsky5T2VPipeline
+  - all
+  - __call__
+
+## Usage Examples
+
+### Basic Text-to-Video Generation
+
+```python
+import torch
+from diffusers import Kandinsky5T2VPipeline
+from diffusers.utils import export_to_video
+
+# Load the pipeline
+model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers"
+pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+pipe = pipe.to("cuda")
+
+# Generate video
+prompt = "A cat and a dog baking a cake together in a kitchen."
+negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards" + +output = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=512, + width=768, + num_frames=121, # ~5 seconds at 24fps + num_inference_steps=50, + guidance_scale=5.0, +).frames[0] + +export_to_video(output, "output.mp4", fps=24, quality=9) +``` + +### 10 second Models +**⚠️ Warning!** all 10 second models should be used with Flex attention and max-autotune-no-cudagraphs compilation: + +```python +pipe = Kandinsky5T2VPipeline.from_pretrained( + "ai-forever/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers", + torch_dtype=torch.bfloat16 +) +pipe = pipe.to("cuda") + +pipe.transformer.set_attention_backend( + "flex" +) # <--- Sett attention bakend to Flex +pipe.transformer.compile( + mode="max-autotune-no-cudagraphs", + dynamic=True +) # <--- Compile with max-autotune-no-cudagraphs + +prompt = "A cat and a dog baking a cake together in a kitchen." +negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards" + +output = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=512, + width=768, + num_frames=241, + num_inference_steps=50, + guidance_scale=5.0, +).frames[0] + +export_to_video(output, "output.mp4", fps=24, quality=9) +``` + +### Diffusion Distilled model +**⚠️ Warning!** all nocfg and diffusion distilled models should be infered wothout CFG (```guidance_scale=1.0```): + +```python +model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers" +pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) +pipe = pipe.to("cuda") + +output = pipe( + prompt="A beautiful sunset over mountains", + num_inference_steps=16, # <--- Model is distilled in 16 steps + guidance_scale=1.0, # <--- no CFG +).frames[0] + +export_to_video(output, "output.mp4", fps=24, quality=9) +``` + + +## Citation +```bibtex +@misc{kandinsky2025, + author = {Alexey Letunovskiy and Maria Kovaleva and Ivan Kirillov and Lev Novitskiy and Denis Koposov and + Dmitrii Mikhailov and Anna Averchenkova and Andrey Shutkin and Julia Agafonova and Olga Kim and + Anastasiia Kargapoltseva and Nikita Kiselev and Vladimir Arkhipkin and Vladimir Korviakov and + Nikolai Gerasimenko and Denis Parkhomenko and Anna Dmitrienko and Anastasia Maltseva and + Kirill Chernyshev and Ilia Vasiliev and Viacheslav Vasilev and Vladimir Polovnikov and + Yury Kolabushin and Alexander Belykh and Mikhail Mamaev and Anastasia Aliaskina and + Tatiana Nikulina and Polina Gavrilova and Denis Dimitrov}, + title = {Kandinsky 5.0: A family of diffusion models for Video & Image generation}, + howpublished = {\url{https://github.com/ai-forever/Kandinsky-5}}, + year = 2025 +} +``` \ No newline at end of file diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index d4ba92acaf..316e79da4f 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -324,6 +324,7 @@ class Kandinsky5AttnProcessor: sparse_params["sta_mask"], thr=sparse_params["P"], ) + else: attn_mask = None @@ -335,6 +336,7 @@ class Kandinsky5AttnProcessor: backend=self._attention_backend, parallel_config=self._parallel_config, ) + hidden_states = hidden_states.flatten(-2, -1) attn_out = attn.out_layer(hidden_states) diff --git 
a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 2b977a5a36..3f93aa1889 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -173,8 +173,10 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): ) self.prompt_template_encode_start_idx = 129 - self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio - self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4 + ) + self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) @staticmethod @@ -384,6 +386,9 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype + if not isinstance(prompt, list): + prompt = [prompt] + batch_size = len(prompt) prompt = [prompt_clean(p) for p in prompt] diff --git a/tests/pipelines/kandinsky5/__init__.py b/tests/pipelines/kandinsky5/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/pipelines/kandinsky5/test_kandinsky5.py b/tests/pipelines/kandinsky5/test_kandinsky5.py new file mode 100644 index 0000000000..47fccb632a --- /dev/null +++ b/tests/pipelines/kandinsky5/test_kandinsky5.py @@ -0,0 +1,306 @@ +# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch +from transformers import ( + CLIPTextConfig, + CLIPTextModel, + CLIPTokenizer, + Qwen2_5_VLConfig, + Qwen2_5_VLForConditionalGeneration, + Qwen2VLProcessor, +) + +from diffusers import ( + AutoencoderKLHunyuanVideo, + FlowMatchEulerDiscreteScheduler, + Kandinsky5T2VPipeline, + Kandinsky5Transformer3DModel, +) + +from ...testing_utils import ( + enable_full_determinism, + torch_device, +) +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class Kandinsky5T2VPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = Kandinsky5T2VPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "prompt_embeds", "negative_prompt_embeds"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + + # Define required optional parameters for your pipeline + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + "max_sequence_length", + ] + ) + + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLHunyuanVideo( + in_channels=3, + out_channels=3, + spatial_compression_ratio=8, + temporal_compression_ratio=4, + latent_channels=4, + block_out_channels=(8, 8, 8, 8), + layers_per_block=1, + norm_num_groups=4, + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + + # Dummy Qwen2.5-VL model + config = Qwen2_5_VLConfig( + text_config={ + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "rope_scaling": { + "mrope_section": [1, 1, 2], + "rope_type": "default", + "type": "default", + }, + "rope_theta": 1000000.0, + }, + vision_config={ + "depth": 2, + "hidden_size": 16, + "intermediate_size": 16, + "num_heads": 2, + "out_hidden_size": 16, + }, + hidden_size=16, + vocab_size=152064, + vision_end_token_id=151653, + vision_start_token_id=151652, + vision_token_id=151654, + ) + text_encoder = Qwen2_5_VLForConditionalGeneration(config) + tokenizer = Qwen2VLProcessor.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") + + # Dummy CLIP model + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder_2 = CLIPTextModel(clip_text_encoder_config) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + torch.manual_seed(0) + transformer = Kandinsky5Transformer3DModel( + in_visual_dim=4, + in_text_dim=16, # Match tiny Qwen2.5-VL hidden size + in_text_dim2=32, # Match tiny CLIP hidden size + time_dim=32, + out_visual_dim=4, + patch_size=(1, 2, 2), + model_dim=48, + ff_dim=128, + num_text_blocks=1, + num_visual_blocks=1, + axes_dims=(8, 8, 8), + visual_cond=False, + ) + + components = { + "transformer": transformer.eval(), + "vae": vae.eval(), + "scheduler": scheduler, + "text_encoder": text_encoder.eval(), + "tokenizer": tokenizer, + "text_encoder_2": text_encoder_2.eval(), + "tokenizer_2": tokenizer_2, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator 
= torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A cat dancing", + "negative_prompt": "blurry, low quality", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 32, + "width": 32, + "num_frames": 5, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + + # Check video shape: (batch, frames, channel, height, width) + expected_shape = (1, 5, 3, 32, 32) + self.assertEqual(video.shape, expected_shape) + + # Check specific values + expected_slice = torch.tensor( + [ + 0.4330, + 0.4254, + 0.4285, + 0.3835, + 0.4253, + 0.4196, + 0.3704, + 0.3714, + 0.4999, + 0.5346, + 0.4795, + 0.4637, + 0.4930, + 0.5124, + 0.4902, + 0.4570, + ] + ) + + generated_slice = video.flatten() + # Take first 8 and last 8 values for comparison + video_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + self.assertTrue( + torch.allclose(video_slice, expected_slice, atol=1e-3), + f"video_slice: {video_slice}, expected_slice: {expected_slice}", + ) + + def test_inference_batch_single_identical(self): + # Override to test batch single identical with video + super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=1e-2) + + def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-3, rtol=1e-3): + components = self.get_dummy_components() + + text_component_names = ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"] + text_components = {k: (v if k in text_component_names else None) for k, v in components.items()} + non_text_components = {k: (v if k not in text_component_names else None) for k, v in components.items()} + + pipe_with_just_text_encoder = self.pipeline_class(**text_components) + pipe_with_just_text_encoder = pipe_with_just_text_encoder.to(torch_device) + + pipe_without_text_encoders = self.pipeline_class(**non_text_components) + pipe_without_text_encoders = pipe_without_text_encoders.to(torch_device) + + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + + # Compute `encode_prompt()`. 
+ + # Test single prompt + prompt = "A cat dancing" + with torch.no_grad(): + prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = pipe_with_just_text_encoder.encode_prompt( + prompt, device=torch_device, max_sequence_length=16 + ) + + # Check shapes + self.assertEqual(prompt_embeds_qwen.shape, (1, 4, 16)) # [batch, seq_len, embed_dim] + self.assertEqual(prompt_embeds_clip.shape, (1, 32)) # [batch, embed_dim] + self.assertEqual(prompt_cu_seqlens.shape, (2,)) # [batch + 1] + + # Test batch of prompts + prompts = ["A cat dancing", "A dog running"] + with torch.no_grad(): + batch_embeds_qwen, batch_embeds_clip, batch_cu_seqlens = pipe_with_just_text_encoder.encode_prompt( + prompts, device=torch_device, max_sequence_length=16 + ) + + # Check batch size + self.assertEqual(batch_embeds_qwen.shape, (len(prompts), 4, 16)) + self.assertEqual(batch_embeds_clip.shape, (len(prompts), 32)) + self.assertEqual(len(batch_cu_seqlens), len(prompts) + 1) # [0, len1, len1+len2] + + inputs = self.get_dummy_inputs(torch_device) + inputs["guidance_scale"] = 1.0 + + # baseline output: full pipeline + pipe_out = pipe(**inputs).frames + + # test against pipeline call with pre-computed prompt embeds + inputs = self.get_dummy_inputs(torch_device) + inputs["guidance_scale"] = 1.0 + + with torch.no_grad(): + prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = pipe_with_just_text_encoder.encode_prompt( + inputs["prompt"], device=torch_device, max_sequence_length=inputs["max_sequence_length"] + ) + + inputs["prompt"] = None + inputs["prompt_embeds_qwen"] = prompt_embeds_qwen + inputs["prompt_embeds_clip"] = prompt_embeds_clip + inputs["prompt_cu_seqlens"] = prompt_cu_seqlens + + pipe_out_2 = pipe_without_text_encoders(**inputs)[0] + + self.assertTrue( + torch.allclose(pipe_out, pipe_out_2, atol=atol, rtol=rtol), + f"max diff: {torch.max(torch.abs(pipe_out - pipe_out_2))}", + ) + + @unittest.skip("Kandinsky5T2VPipeline does not support attention slicing") + def test_attention_slicing_forward_pass(self): + pass + + @unittest.skip("Kandinsky5T2VPipeline does not support xformers") + def test_xformers_attention_forwardGenerator_pass(self): + pass + + @unittest.skip("Kandinsky5T2VPipeline does not support VAE slicing") + def test_vae_slicing(self): + pass diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index db8209835b..2af4ad0314 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1461,6 +1461,8 @@ class PipelineTesterMixin: def test_save_load_optional_components(self, expected_max_difference=1e-4): if not hasattr(self.pipeline_class, "_optional_components"): return + if not self.pipeline_class._optional_components: + return components = self.get_dummy_components() pipe = self.pipeline_class(**components) for component in pipe.components.values(): From 303efd2b8d9949758004158c89b6c29a816150a2 Mon Sep 17 00:00:00 2001 From: "G.O.D" <32255912+gameofdimension@users.noreply.github.com> Date: Tue, 28 Oct 2025 10:55:36 +0800 Subject: [PATCH 06/16] Improve pos embed for Flux.1 inference on Ascend NPU (#12534) improve pos embed for ascend npu Co-authored-by: felix01.yu --- src/diffusers/models/transformers/transformer_flux.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 1a44644324..16c526f437 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ 
b/src/diffusers/models/transformers/transformer_flux.py
@@ -22,7 +22,7 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -717,7 +717,11 @@ class FluxTransformer2DModel(
             img_ids = img_ids[0]
 
         ids = torch.cat((txt_ids, img_ids), dim=0)
-        image_rotary_emb = self.pos_embed(ids)
+        if is_torch_npu_available():
+            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
+            image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
+        else:
+            image_rotary_emb = self.pos_embed(ids)
 
         if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
             ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")

From df0e2a4f2c2ebd7b93e6f51602d8d054afd8cd86 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Tue, 28 Oct 2025 08:55:24 +0530
Subject: [PATCH 07/16] support latest few-step wan LoRA. (#12541)

* support latest few-step wan LoRA.

* up

* up
---
 .../loaders/lora_conversion_utils.py | 30 +++++++++++++++----
 1 file changed, 25 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index 89afb6529a..099dbfc1d2 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -1977,14 +1977,34 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
                 "time_projection.1.diff_b"
             )
 
-    if any("head.head" in k for k in state_dict):
-        converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
-            f"head.head.{lora_down_key}.weight"
-        )
-        converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight")
+    if any("head.head" in k for k in original_state_dict):
+        if any(f"head.head.{lora_down_key}.weight" in k for k in state_dict):
+            converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
+                f"head.head.{lora_down_key}.weight"
+            )
+        if any(f"head.head.{lora_up_key}.weight" in k for k in state_dict):
+            converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(
+                f"head.head.{lora_up_key}.weight"
+            )
         if "head.head.diff_b" in original_state_dict:
             converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b")
+
+        # Notes: https://huggingface.co/lightx2v/Wan2.2-Distill-Loras
+        # This is my (sayakpaul) assumption that this particular key belongs to the down matrix.
+        # Since for this particular LoRA, we don't have the corresponding up matrix, I will use
+        # an identity.
+        if any("head.head" in k and k.endswith(".diff") for k in state_dict):
+            if f"head.head.{lora_down_key}.weight" in state_dict:
+                logger.info(
+                    f"The state dict seems to have both `head.head.diff` and `head.head.{lora_down_key}.weight` keys, which is unexpected."
+ ) + converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop("head.head.diff") + down_matrix_head = converted_state_dict["proj_out.lora_A.weight"] + up_matrix_shape = (down_matrix_head.shape[0], converted_state_dict["proj_out.lora_B.bias"].shape[0]) + converted_state_dict["proj_out.lora_B.weight"] = torch.eye( + *up_matrix_shape, dtype=down_matrix_head.dtype, device=down_matrix_head.device + ).T + for text_time in ["text_embedding", "time_embedding"]: if any(text_time in k for k in original_state_dict): for b_n in [0, 2]: From ecfbc8f952fb20b47d2f85695e98d64c2f5b219a Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 28 Oct 2025 09:21:31 +0530 Subject: [PATCH 08/16] [Pipelines] Enable Wan VACE to run since single transformer (#12428) * update * update * update * update * update --- .../pipelines/wan/pipeline_wan_vace.py | 73 ++++++++++----- tests/pipelines/wan/test_wan_vace.py | 89 ++++++++++++++++++- 2 files changed, 137 insertions(+), 25 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_vace.py b/src/diffusers/pipelines/wan/pipeline_wan_vace.py index 2b1890afec..63e557a98f 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_vace.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_vace.py @@ -152,34 +152,36 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin): text_encoder ([`T5EncoderModel`]): [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. - transformer ([`WanVACETransformer3DModel`]): - Conditional Transformer to denoise the input latents. - transformer_2 ([`WanVACETransformer3DModel`], *optional*): - Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising, - `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only - `transformer` is used. - scheduler ([`UniPCMultistepScheduler`]): - A scheduler to be used in combination with `transformer` to denoise the encoded image latents. vae ([`AutoencoderKLWan`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + transformer ([`WanVACETransformer3DModel`], *optional*): + Conditional Transformer to denoise the input latents during the high-noise stage. In two-stage denoising, + `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of + `transformer` or `transformer_2` must be provided. + transformer_2 ([`WanVACETransformer3DModel`], *optional*): + Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising, + `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of + `transformer` or `transformer_2` must be provided. boundary_ratio (`float`, *optional*, defaults to `None`): Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising. The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided, `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < - boundary_timestep. If `None`, only `transformer` is used for the entire denoising process. + boundary_timestep. 
If `None`, only the available transformer is used for the entire denoising process. """ - model_cpu_offload_seq = "text_encoder->transformer->vae" + model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] - _optional_components = ["transformer_2"] + _optional_components = ["transformer", "transformer_2"] def __init__( self, tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, - transformer: WanVACETransformer3DModel, vae: AutoencoderKLWan, scheduler: FlowMatchEulerDiscreteScheduler, + transformer: WanVACETransformer3DModel = None, transformer_2: WanVACETransformer3DModel = None, boundary_ratio: Optional[float] = None, ): @@ -336,7 +338,15 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin): reference_images=None, guidance_scale_2=None, ): - base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1] + if self.transformer is not None: + base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1] + elif self.transformer_2 is not None: + base = self.vae_scale_factor_spatial * self.transformer_2.config.patch_size[1] + else: + raise ValueError( + "`transformer` or `transformer_2` component must be set in order to run inference with this pipeline" + ) + if height % base != 0 or width % base != 0: raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.") @@ -414,7 +424,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin): device: Optional[torch.device] = None, ): if video is not None: - base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1] + base = self.vae_scale_factor_spatial * ( + self.transformer.config.patch_size[1] + if self.transformer is not None + else self.transformer_2.config.patch_size[1] + ) video_height, video_width = self.video_processor.get_default_height_width(video[0]) if video_height * video_width > height * width: @@ -589,7 +603,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin): "Generating with more than one video is not yet supported. This may be supported in the future." ) - transformer_patch_size = self.transformer.config.patch_size[1] + transformer_patch_size = ( + self.transformer.config.patch_size[1] + if self.transformer is not None + else self.transformer_2.config.patch_size[1] + ) mask_list = [] for mask_, reference_images_batch in zip(mask, reference_images): @@ -844,20 +862,25 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin): batch_size = prompt_embeds.shape[0] vae_dtype = self.vae.dtype - transformer_dtype = self.transformer.dtype + transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype + vace_layers = ( + self.transformer.config.vace_layers + if self.transformer is not None + else self.transformer_2.config.vace_layers + ) if isinstance(conditioning_scale, (int, float)): - conditioning_scale = [conditioning_scale] * len(self.transformer.config.vace_layers) + conditioning_scale = [conditioning_scale] * len(vace_layers) if isinstance(conditioning_scale, list): - if len(conditioning_scale) != len(self.transformer.config.vace_layers): + if len(conditioning_scale) != len(vace_layers): raise ValueError( - f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(self.transformer.config.vace_layers)}." + f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(vace_layers)}." 
) conditioning_scale = torch.tensor(conditioning_scale) if isinstance(conditioning_scale, torch.Tensor): - if conditioning_scale.size(0) != len(self.transformer.config.vace_layers): + if conditioning_scale.size(0) != len(vace_layers): raise ValueError( - f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(self.transformer.config.vace_layers)}." + f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(vace_layers)}." ) conditioning_scale = conditioning_scale.to(device=device, dtype=transformer_dtype) @@ -900,7 +923,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin): conditioning_latents = torch.cat([conditioning_latents, mask], dim=1) conditioning_latents = conditioning_latents.to(transformer_dtype) - num_channels_latents = self.transformer.config.in_channels + num_channels_latents = ( + self.transformer.config.in_channels + if self.transformer is not None + else self.transformer_2.config.in_channels + ) latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, @@ -968,7 +995,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin): attention_kwargs=attention_kwargs, return_dict=False, )[0] - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] diff --git a/tests/pipelines/wan/test_wan_vace.py b/tests/pipelines/wan/test_wan_vace.py index f99863c880..fe078c0deb 100644 --- a/tests/pipelines/wan/test_wan_vace.py +++ b/tests/pipelines/wan/test_wan_vace.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
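+# Rough usage sketch of the single-transformer path exercised by the new tests in this file
+# (the component objects here are assumed to be built elsewhere, e.g. via get_dummy_components()):
+#
+#     pipe = WanVACEPipeline(
+#         tokenizer=tokenizer,
+#         text_encoder=text_encoder,
+#         vae=vae,
+#         scheduler=UniPCMultistepScheduler(
+#             prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
+#         ),
+#         transformer=None,
+#         transformer_2=transformer_2,
+#         boundary_ratio=1.0,
+#     )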
+import tempfile import unittest import numpy as np @@ -19,9 +20,15 @@ import torch from PIL import Image from transformers import AutoTokenizer, T5EncoderModel -from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel +from diffusers import ( + AutoencoderKLWan, + FlowMatchEulerDiscreteScheduler, + UniPCMultistepScheduler, + WanVACEPipeline, + WanVACETransformer3DModel, +) -from ...testing_utils import enable_full_determinism +from ...testing_utils import enable_full_determinism, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import PipelineTesterMixin @@ -212,3 +219,81 @@ class WanVACEPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) def test_save_load_float16(self): pass + + def test_inference_with_only_transformer(self): + components = self.get_dummy_components() + components["transformer_2"] = None + components["boundary_ratio"] = 0.0 + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + video = pipe(**inputs).frames[0] + assert video.shape == (17, 3, 16, 16) + + def test_inference_with_only_transformer_2(self): + components = self.get_dummy_components() + components["transformer_2"] = components["transformer"] + components["transformer"] = None + + # FlowMatchEulerDiscreteScheduler doesn't support running low noise only scheduler + # because starting timestep t == 1000 == boundary_timestep + components["scheduler"] = UniPCMultistepScheduler( + prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0 + ) + + components["boundary_ratio"] = 1.0 + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + video = pipe(**inputs).frames[0] + assert video.shape == (17, 3, 16, 16) + + def test_save_load_optional_components(self, expected_max_difference=1e-4): + optional_component = ["transformer"] + + components = self.get_dummy_components() + components["transformer_2"] = components["transformer"] + # FlowMatchEulerDiscreteScheduler doesn't support running low noise only scheduler + # because starting timestep t == 1000 == boundary_timestep + components["scheduler"] = UniPCMultistepScheduler( + prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0 + ) + for component in optional_component: + components[component] = None + + components["boundary_ratio"] = 1.0 + + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + torch.manual_seed(0) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for component in optional_component: + assert getattr(pipe_loaded, component) is None, f"`{component}` did not stay set to None after loading." 
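+        # Re-run inference with the reloaded pipeline and check it still matches the original output.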
+
+        inputs = self.get_dummy_inputs(generator_device)
+        torch.manual_seed(0)
+        output_loaded = pipe_loaded(**inputs)[0]
+
+        max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+        assert max_diff < expected_max_difference, "Outputs exceed expected maximum difference"

From dc622a95d0c8dcd1a0badbe6cf6e61ef4eca3917 Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Tue, 28 Oct 2025 11:59:20 +0800
Subject: [PATCH 09/16] fix crash if tiling mode is enabled (#12521)

* fix crash if tiling mode is enabled

Signed-off-by: Wang, Yi A

* fmt

Signed-off-by: Wang, Yi A

---------

Signed-off-by: Wang, Yi A
Co-authored-by: Sayak Paul
---
 .../models/autoencoders/autoencoder_kl_wan.py | 27 ++++++++++++++-----
 1 file changed, 21 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
index dae26f8086..f8bdfeb755 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -1337,9 +1337,18 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
         tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
         tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
         tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
-
-        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
-        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+        tile_sample_stride_height = self.tile_sample_stride_height
+        tile_sample_stride_width = self.tile_sample_stride_width
+        if self.config.patch_size is not None:
+            sample_height = sample_height // self.config.patch_size
+            sample_width = sample_width // self.config.patch_size
+            tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size
+            tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size
+            blend_height = self.tile_sample_min_height // self.config.patch_size - tile_sample_stride_height
+            blend_width = self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width
+        else:
+            blend_height = self.tile_sample_min_height - tile_sample_stride_height
+            blend_width = self.tile_sample_min_width - tile_sample_stride_width
 
         # Split z into overlapping tiles and decode them separately.
         # The tiles have an overlap to avoid seams between tiles.
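For a concrete feel of the new blend math above, here is a small standalone sketch with assumed tile sizes (the real values come from the VAE config):

```python
# Assumed example values; the actual ones are read from the Wan VAE config.
tile_sample_min_height = 256
tile_sample_stride_height = 192
patch_size = 2  # None for configs without patchification

if patch_size is not None:
    stride = tile_sample_stride_height // patch_size              # 96
    blend_height = tile_sample_min_height // patch_size - stride  # 128 - 96 = 32
else:
    blend_height = tile_sample_min_height - tile_sample_stride_height  # 64

# With patch_size set, the overlap is computed on the pre-unpatchify grid, which is the
# grid the decoded tiles are blended on before `unpatchify` runs at the end of decoding.
print(blend_height)
```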
@@ -1353,7 +1362,9 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo self._conv_idx = [0] tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width] tile = self.post_quant_conv(tile) - decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx) + decoded = self.decoder( + tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0) + ) time.append(decoded) row.append(torch.cat(time, dim=2)) rows.append(row) @@ -1369,11 +1380,15 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo tile = self.blend_v(rows[i - 1][j], tile, blend_height) if j > 0: tile = self.blend_h(row[j - 1], tile, blend_width) - result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width]) result_rows.append(torch.cat(result_row, dim=-1)) - dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + if self.config.patch_size is not None: + dec = unpatchify(dec, patch_size=self.config.patch_size) + + dec = torch.clamp(dec, min=-1.0, max=1.0) + if not return_dict: return (dec,) return DecoderOutput(sample=dec) From 40528e9ae7d56740c00d838299198d34111717bb Mon Sep 17 00:00:00 2001 From: Meatfucker <74834323+Meatfucker@users.noreply.github.com> Date: Tue, 28 Oct 2025 01:54:24 -0400 Subject: [PATCH 10/16] Fix typos in kandinsky5 docs (#12552) Update kandinsky5.md Fix typos --- docs/source/en/api/pipelines/kandinsky5.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/en/api/pipelines/kandinsky5.md b/docs/source/en/api/pipelines/kandinsky5.md index cb1c119f80..a98a0826b7 100644 --- a/docs/source/en/api/pipelines/kandinsky5.md +++ b/docs/source/en/api/pipelines/kandinsky5.md @@ -92,7 +92,7 @@ pipe = pipe.to("cuda") pipe.transformer.set_attention_backend( "flex" -) # <--- Sett attention bakend to Flex +) # <--- Set attention backend to Flex pipe.transformer.compile( mode="max-autotune-no-cudagraphs", dynamic=True @@ -115,7 +115,7 @@ export_to_video(output, "output.mp4", fps=24, quality=9) ``` ### Diffusion Distilled model -**⚠️ Warning!** all nocfg and diffusion distilled models should be infered wothout CFG (```guidance_scale=1.0```): +**⚠️ Warning!** all nocfg and diffusion distilled models should be inferred without CFG (```guidance_scale=1.0```): ```python model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers" @@ -146,4 +146,4 @@ export_to_video(output, "output.mp4", fps=24, quality=9) howpublished = {\url{https://github.com/ai-forever/Kandinsky-5}}, year = 2025 } -``` \ No newline at end of file +``` From 55d49d4379007740af20629bb61aba9546c6b053 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 28 Oct 2025 13:29:51 +0530 Subject: [PATCH 11/16] [ci] don't run sana layerwise casting tests in CI. (#12551) * don't run sana layerwise casting tests in CI. 
* up --- tests/lora/test_lora_layers_sana.py | 6 +++++- .../models/autoencoders/test_models_autoencoder_dc.py | 10 +++++----- tests/pipelines/sana/test_sana.py | 5 +++++ tests/pipelines/sana/test_sana_controlnet.py | 9 +++++---- tests/pipelines/sana/test_sana_sprint.py | 9 +++++---- tests/pipelines/sana/test_sana_sprint_img2img.py | 9 +++++---- tests/testing_utils.py | 2 ++ 7 files changed, 32 insertions(+), 18 deletions(-) diff --git a/tests/lora/test_lora_layers_sana.py b/tests/lora/test_lora_layers_sana.py index 3cdb28de75..a860b7b44f 100644 --- a/tests/lora/test_lora_layers_sana.py +++ b/tests/lora/test_lora_layers_sana.py @@ -20,7 +20,7 @@ from transformers import Gemma2Model, GemmaTokenizer from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel -from ..testing_utils import floats_tensor, require_peft_backend +from ..testing_utils import IS_GITHUB_ACTIONS, floats_tensor, require_peft_backend sys.path.append(".") @@ -136,3 +136,7 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Text encoder LoRA is not supported in SANA.") def test_simple_inference_with_text_lora_save_load(self): pass + + @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment") + def test_layerwise_casting_inference_denoiser(self): + return super().test_layerwise_casting_inference_denoiser() diff --git a/tests/models/autoencoders/test_models_autoencoder_dc.py b/tests/models/autoencoders/test_models_autoencoder_dc.py index a6912f3eba..d34001e7b9 100644 --- a/tests/models/autoencoders/test_models_autoencoder_dc.py +++ b/tests/models/autoencoders/test_models_autoencoder_dc.py @@ -17,11 +17,7 @@ import unittest from diffusers import AutoencoderDC -from ...testing_utils import ( - enable_full_determinism, - floats_tensor, - torch_device, -) +from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, floats_tensor, torch_device from ..test_modeling_common import ModelTesterMixin from .testing_utils import AutoencoderTesterMixin @@ -82,3 +78,7 @@ class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.Test init_dict = self.get_autoencoder_dc_config() inputs_dict = self.dummy_input return init_dict, inputs_dict + + @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment") + def test_layerwise_casting_inference(self): + super().test_layerwise_casting_inference() diff --git a/tests/pipelines/sana/test_sana.py b/tests/pipelines/sana/test_sana.py index 34ea3079b1..f23303c966 100644 --- a/tests/pipelines/sana/test_sana.py +++ b/tests/pipelines/sana/test_sana.py @@ -23,6 +23,7 @@ from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel from ...testing_utils import ( + IS_GITHUB_ACTIONS, backend_empty_cache, enable_full_determinism, require_torch_accelerator, @@ -304,6 +305,10 @@ class SanaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): # Requires higher tolerance as model seems very sensitive to dtype super().test_float16_inference(expected_max_diff=0.08) + @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment") + def test_layerwise_casting_inference(self): + super().test_layerwise_casting_inference() + @slow @require_torch_accelerator diff --git a/tests/pipelines/sana/test_sana_controlnet.py b/tests/pipelines/sana/test_sana_controlnet.py index 043e276fcb..df14d935ed 100644 --- 
a/tests/pipelines/sana/test_sana_controlnet.py +++ b/tests/pipelines/sana/test_sana_controlnet.py @@ -28,10 +28,7 @@ from diffusers import ( ) from diffusers.utils.torch_utils import randn_tensor -from ...testing_utils import ( - enable_full_determinism, - torch_device, -) +from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import PipelineTesterMixin, to_np @@ -326,3 +323,7 @@ class SanaControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase): def test_float16_inference(self): # Requires higher tolerance as model seems very sensitive to dtype super().test_float16_inference(expected_max_diff=0.08) + + @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment") + def test_layerwise_casting_inference(self): + super().test_layerwise_casting_inference() diff --git a/tests/pipelines/sana/test_sana_sprint.py b/tests/pipelines/sana/test_sana_sprint.py index fee2304dce..0d45205ea8 100644 --- a/tests/pipelines/sana/test_sana_sprint.py +++ b/tests/pipelines/sana/test_sana_sprint.py @@ -21,10 +21,7 @@ from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer from diffusers import AutoencoderDC, SanaSprintPipeline, SanaTransformer2DModel, SCMScheduler -from ...testing_utils import ( - enable_full_determinism, - torch_device, -) +from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import PipelineTesterMixin, to_np @@ -300,3 +297,7 @@ class SanaSprintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): def test_float16_inference(self): # Requires higher tolerance as model seems very sensitive to dtype super().test_float16_inference(expected_max_diff=0.08) + + @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment") + def test_layerwise_casting_inference(self): + super().test_layerwise_casting_inference() diff --git a/tests/pipelines/sana/test_sana_sprint_img2img.py b/tests/pipelines/sana/test_sana_sprint_img2img.py index c218abb8e9..5de5c7f446 100644 --- a/tests/pipelines/sana/test_sana_sprint_img2img.py +++ b/tests/pipelines/sana/test_sana_sprint_img2img.py @@ -22,10 +22,7 @@ from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer from diffusers import AutoencoderDC, SanaSprintImg2ImgPipeline, SanaTransformer2DModel, SCMScheduler from diffusers.utils.torch_utils import randn_tensor -from ...testing_utils import ( - enable_full_determinism, - torch_device, -) +from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, torch_device from ..pipeline_params import ( IMAGE_TO_IMAGE_IMAGE_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, @@ -312,3 +309,7 @@ class SanaSprintImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase) def test_float16_inference(self): # Requires higher tolerance as model seems very sensitive to dtype super().test_float16_inference(expected_max_diff=0.08) + + @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment") + def test_layerwise_casting_inference(self): + super().test_layerwise_casting_inference() diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 7f849219c1..951ba41280 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -63,6 +63,8 @@ else: 
IS_CUDA_SYSTEM = False IS_XPU_SYSTEM = False +IS_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" and os.getenv("DIFFUSERS_IS_CI") == "yes" + global_rng = random.Random() logger = get_logger(__name__) From 84e16575e4c5e90b6b49301cfa162ced4cf478d2 Mon Sep 17 00:00:00 2001 From: galbria <158810732+galbria@users.noreply.github.com> Date: Tue, 28 Oct 2025 12:57:48 +0200 Subject: [PATCH 12/16] Bria fibo (#12545) * Bria FIBO pipeline * style fixs * fix CR * Refactor BriaFibo classes and update pipeline parameters - Updated BriaFiboAttnProcessor and BriaFiboAttention classes to reflect changes from Flux equivalents. - Modified the _unpack_latents method in BriaFiboPipeline to improve clarity. - Increased the default max_sequence_length to 3000 and added a new optional parameter do_patching. - Cleaned up test_pipeline_bria_fibo.py by removing unused imports and skipping unsupported tests. * edit the docs of FIBO * Remove unused BriaFibo imports and update CPU offload method in BriaFiboPipeline * Refactor FIBO classes to BriaFibo naming convention - Updated class names from FIBO to BriaFibo for consistency across the module. - Modified instances of FIBOEmbedND, FIBOTimesteps, TextProjection, and TimestepProjEmbeddings to reflect the new naming. - Ensured all references in the BriaFiboTransformer2DModel are updated accordingly. * Add BriaFiboTransformer2DModel import to transformers module * Remove unused BriaFibo imports from modular pipelines and add BriaFiboTransformer2DModel and BriaFiboPipeline classes to dummy objects for enhanced compatibility with torch and transformers. * Update BriaFibo classes with copied documentation and fix import typo in pipeline module - Added documentation comments indicating the source of copied code in BriaFiboTransformerBlock and _pack_latents methods. - Corrected the import statement for BriaFiboPipeline in the pipelines module. * Remove unused BriaFibo imports from __init__.py to streamline modular pipelines. * Refactor documentation comments in BriaFibo classes to indicate inspiration from existing implementations - Updated comments in BriaFiboAttnProcessor, BriaFiboAttention, and BriaFiboPipeline to reflect that the code is inspired by other modules rather than copied. - Enhanced clarity on the origins of the methods to maintain proper attribution. * change Inspired by to Based on * add reference link and fix trailing whitespace * Add BriaFiboTransformer2DModel documentation and update comments in BriaFibo classes - Introduced a new documentation file for BriaFiboTransformer2DModel. - Updated comments in BriaFiboAttnProcessor, BriaFiboAttention, and BriaFiboPipeline to clarify the origins of the code, indicating copied sources for better attribution. 
--------- Co-authored-by: sayakpaul --- docs/source/en/_toctree.yml | 4 + .../en/api/models/transformer_bria_fibo.md | 19 + docs/source/en/api/pipelines/bria_fibo.md | 45 + src/diffusers/__init__.py | 4 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_bria_fibo.py | 655 ++++++++++++++ src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/bria_fibo/__init__.py | 48 + .../pipelines/bria_fibo/pipeline_bria_fibo.py | 838 ++++++++++++++++++ .../pipelines/bria_fibo/pipeline_output.py | 21 + src/diffusers/utils/dummy_pt_objects.py | 15 + .../dummy_torch_and_transformers_objects.py | 15 + .../test_models_transformer_bria_fibo.py | 89 ++ tests/pipelines/bria_fibo/__init__.py | 0 .../bria_fibo/test_pipeline_bria_fibo.py | 139 +++ 16 files changed, 1897 insertions(+) create mode 100644 docs/source/en/api/models/transformer_bria_fibo.md create mode 100644 docs/source/en/api/pipelines/bria_fibo.md create mode 100644 src/diffusers/models/transformers/transformer_bria_fibo.py create mode 100644 src/diffusers/pipelines/bria_fibo/__init__.py create mode 100644 src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py create mode 100644 src/diffusers/pipelines/bria_fibo/pipeline_output.py create mode 100644 tests/models/transformers/test_models_transformer_bria_fibo.py create mode 100644 tests/pipelines/bria_fibo/__init__.py create mode 100644 tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 44870f680e..8103c01643 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -323,6 +323,8 @@ title: AllegroTransformer3DModel - local: api/models/aura_flow_transformer2d title: AuraFlowTransformer2DModel + - local: api/models/transformer_bria_fibo + title: BriaFiboTransformer2DModel - local: api/models/bria_transformer title: BriaTransformer2DModel - local: api/models/chroma_transformer @@ -469,6 +471,8 @@ title: BLIP-Diffusion - local: api/pipelines/bria_3_2 title: Bria 3.2 + - local: api/pipelines/bria_fibo + title: Bria Fibo - local: api/pipelines/chroma title: Chroma - local: api/pipelines/cogview3 diff --git a/docs/source/en/api/models/transformer_bria_fibo.md b/docs/source/en/api/models/transformer_bria_fibo.md new file mode 100644 index 0000000000..5691746ccd --- /dev/null +++ b/docs/source/en/api/models/transformer_bria_fibo.md @@ -0,0 +1,19 @@ + + +# BriaFiboTransformer2DModel + +A modified flux Transformer model from [Bria](https://huggingface.co/briaai/FIBO) + +## BriaFiboTransformer2DModel + +[[autodoc]] BriaFiboTransformer2DModel diff --git a/docs/source/en/api/pipelines/bria_fibo.md b/docs/source/en/api/pipelines/bria_fibo.md new file mode 100644 index 0000000000..96cad10dda --- /dev/null +++ b/docs/source/en/api/pipelines/bria_fibo.md @@ -0,0 +1,45 @@ + + +# Bria Fibo + +Text-to-image models have mastered imagination - but not control. FIBO changes that. + +FIBO is trained on structured JSON captions up to 1,000+ words and designed to understand and control different visual parameters such as lighting, composition, color, and camera settings, enabling precise and reproducible outputs. + +With only 8 billion parameters, FIBO provides a new level of image quality, prompt adherence and proffesional control. + +FIBO is trained exclusively on a structured prompt and will not work with freeform text prompts. 
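A rough end-to-end sketch of the intended two-step use (convert a freeform prompt to structured JSON, then generate), mirroring the example docstring added to `BriaFiboPipeline` later in this patch:

```python
import torch
from diffusers import BriaFiboPipeline
from diffusers.modular_pipelines import ModularPipeline

# Convert a freeform prompt into the structured JSON prompt FIBO expects.
vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)
json_prompt = vlm_pipe(prompt="a beautiful dog").values["json_prompt"]

pipe = BriaFiboPipeline.from_pretrained("briaai/FIBO", trust_remote_code=True, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

with torch.inference_mode():
    image = pipe(prompt=json_prompt, num_inference_steps=50, guidance_scale=5).images[0]
image.save("image.png")
```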
+you can use the [FIBO-VLM-prompt-to-JSON](https://huggingface.co/briaai/FIBO-VLM-prompt-to-JSON) model or the [FIBO-gemini-prompt-to-JSON](https://huggingface.co/briaai/FIBO-gemini-prompt-to-JSON) to convert your freeform text prompt to a structured JSON prompt. + +its not recommended to use freeform text prompts directly with FIBO, as it will not produce the best results. + +you can learn more about FIBO in [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO). + + +## Usage + +_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO), fill in the form and accept the gate. Once you are in, you need to login so that your system knows you’ve accepted the gate._ + +Use the command below to log in: + +```bash +hf auth login +``` + + +## BriaPipeline + +[[autodoc]] BriaPipeline + - all + - __call__ + diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 2a8b589846..94104667b5 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -198,6 +198,7 @@ else: "AutoencoderOobleck", "AutoencoderTiny", "AutoModel", + "BriaFiboTransformer2DModel", "BriaTransformer2DModel", "CacheMixin", "ChromaTransformer2DModel", @@ -430,6 +431,7 @@ else: "AuraFlowPipeline", "BlipDiffusionControlNetPipeline", "BlipDiffusionPipeline", + "BriaFiboPipeline", "BriaPipeline", "ChromaImg2ImgPipeline", "ChromaPipeline", @@ -901,6 +903,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: AutoencoderOobleck, AutoencoderTiny, AutoModel, + BriaFiboTransformer2DModel, BriaTransformer2DModel, CacheMixin, ChromaTransformer2DModel, @@ -1103,6 +1106,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: AudioLDM2UNet2DConditionModel, AudioLDMPipeline, AuraFlowPipeline, + BriaFiboPipeline, BriaPipeline, ChromaImg2ImgPipeline, ChromaPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 241ccf7b78..e3b2974641 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -84,6 +84,7 @@ if is_torch_available(): _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] _import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"] + _import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"] _import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"] _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] _import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"] @@ -174,6 +175,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .transformers import ( AllegroTransformer3DModel, AuraFlowTransformer2DModel, + BriaFiboTransformer2DModel, BriaTransformer2DModel, ChromaTransformer2DModel, CogVideoXTransformer3DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 6c66131dea..2fe1159eec 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -18,6 +18,7 @@ if is_torch_available(): from .transformer_2d import Transformer2DModel from .transformer_allegro import AllegroTransformer3DModel from .transformer_bria import BriaTransformer2DModel + from .transformer_bria_fibo import BriaFiboTransformer2DModel from .transformer_chroma import ChromaTransformer2DModel from .transformer_cogview3plus import 
CogView3PlusTransformer2DModel from .transformer_cogview4 import CogView4Transformer2DModel diff --git a/src/diffusers/models/transformers/transformer_bria_fibo.py b/src/diffusers/models/transformers/transformer_bria_fibo.py new file mode 100644 index 0000000000..09f7961932 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_bria_fibo.py @@ -0,0 +1,655 @@ +# Copyright (c) Bria.ai. All rights reserved. +# +# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0). +# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/ +# +# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit, +# indicate if changes were made, and do not use the material for commercial purposes. +# +# See the license for further details. +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...models.attention_processor import Attention +from ...models.embeddings import TimestepEmbedding, apply_rotary_emb, get_1d_rotary_pos_embed, get_timestep_embedding +from ...models.modeling_outputs import Transformer2DModelOutput +from ...models.modeling_utils import ModelMixin +from ...models.transformers.transformer_bria import BriaAttnProcessor +from ...utils import ( + USE_PEFT_BACKEND, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _get_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_fused_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None): + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + + encoder_query = encoder_key = encoder_value = (None,) + if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): + encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_qkv_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None): + if attn.fused_projections: + return _get_fused_projections(attn, hidden_states, encoder_hidden_states) + return _get_projections(attn, hidden_states, encoder_hidden_states) + + +# Copied from diffusers.models.transformers.transformer_flux.FluxAttnProcessor with FluxAttnProcessor->BriaFiboAttnProcessor, FluxAttention->BriaFiboAttention +class BriaFiboAttnProcessor: + _attention_backend = None + _parallel_config = None + + def 
__init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.") + + def __call__( + self, + attn: "BriaFiboAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +# Based on https://github.com/huggingface/diffusers/blob/55d49d4379007740af20629bb61aba9546c6b053/src/diffusers/models/transformers/transformer_flux.py +class BriaFiboAttention(torch.nn.Module, AttentionModuleMixin): + _default_processor_cls = BriaFiboAttnProcessor + _available_processors = [BriaFiboAttnProcessor] + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + context_pre_only: Optional[bool] = None, + pre_only: bool = False, + elementwise_affine: bool = True, + processor=None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.heads = out_dim // dim_head if out_dim is not None else heads + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + 
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.pre_only: + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if added_kv_proj_dim is not None: + self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) + self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) + self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + + +class BriaFiboEmbedND(torch.nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + freqs_dtype = torch.float32 if is_mps else torch.float64 + for i in range(n_axes): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +@maybe_allow_in_graph +class BriaFiboSingleTransformerBlock(nn.Module): + def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm = AdaLayerNormZeroSingle(dim) + self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) + + processor = BriaAttnProcessor() + + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + processor=processor, + qk_norm="rms_norm", + eps=1e-6, + pre_only=True, + ) + + def forward( + self, + 
hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.Tensor: + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + return hidden_states + + +class BriaFiboTextProjection(nn.Module): + def __init__(self, in_features, hidden_size): + super().__init__() + self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False) + + def forward(self, caption): + hidden_states = self.linear(caption) + return hidden_states + + +@maybe_allow_in_graph +# Based on from diffusers.models.transformers.transformer_flux.FluxTransformerBlock +class BriaFiboTransformerBlock(nn.Module): + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 + ): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim) + self.norm1_context = AdaLayerNormZero(dim) + + self.attn = BriaFiboAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=BriaFiboAttnProcessor(), + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + joint_attention_kwargs = joint_attention_kwargs or {} + + # Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. 
+ attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class BriaFiboTimesteps(nn.Module): + def __init__( + self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000 + ): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + self.time_theta = time_theta + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + max_period=self.time_theta, + ) + return t_emb + + +class BriaFiboTimestepProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, time_theta): + super().__init__() + + self.time_proj = BriaFiboTimesteps( + num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta + ) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timestep, dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D) + return timesteps_emb + + +class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + """ + Parameters: + patch_size (`int`): Patch size to turn the input data into small patches. + in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. + num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. + num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. + joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. + guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. + ... 
+ """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + patch_size: int = 1, + in_channels: int = 64, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 4096, + pooled_projection_dim: int = None, + guidance_embeds: bool = False, + axes_dims_rope: List[int] = [16, 56, 56], + rope_theta=10000, + time_theta=10000, + text_encoder_dim: int = 2048, + ): + super().__init__() + self.out_channels = in_channels + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + + self.pos_embed = BriaFiboEmbedND(theta=rope_theta, axes_dim=axes_dims_rope) + + self.time_embed = BriaFiboTimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta) + + if guidance_embeds: + self.guidance_embed = BriaFiboTimestepProjEmbeddings(embedding_dim=self.inner_dim) + + self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim) + self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + BriaFiboTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for i in range(self.config.num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + BriaFiboSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for i in range(self.config.num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + caption_projection = [ + BriaFiboTextProjection(in_features=text_encoder_dim, hidden_size=self.inner_dim // 2) + for i in range(self.config.num_layers + self.config.num_single_layers) + ] + self.caption_projection = nn.ModuleList(caption_projection) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + text_encoder_layers: list = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) + else: + guidance = None + + temb = self.time_embed(timestep, dtype=hidden_states.dtype) + + if guidance: + temb += self.guidance_embed(guidance, dtype=hidden_states.dtype) + + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if len(txt_ids.shape) == 3: + txt_ids = txt_ids[0] + + if len(img_ids.shape) == 3: + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + + new_text_encoder_layers = [] + for i, text_encoder_layer in enumerate(text_encoder_layers): + text_encoder_layer = self.caption_projection[i](text_encoder_layer) + new_text_encoder_layers.append(text_encoder_layer) + text_encoder_layers = new_text_encoder_layers + + block_id = 0 + for index_block, block in enumerate(self.transformer_blocks): + current_text_encoder_layer = text_encoder_layers[block_id] + encoder_hidden_states = torch.cat( + [encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1 + ) + block_id += 1 + if torch.is_grad_enabled() and self.gradient_checkpointing: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + joint_attention_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + for index_block, block in enumerate(self.single_transformer_blocks): + current_text_encoder_layer = text_encoder_layers[block_id] + encoder_hidden_states = torch.cat( + [encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1 + ) + block_id += 1 + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + temb, + image_rotary_emb, + joint_attention_kwargs, + ) + + else: + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + encoder_hidden_states = hidden_states[:, : encoder_hidden_states.shape[1], ...] + hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
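+        # Split the joint sequence back into text and image tokens; only the image tokens go
+        # through the final AdaLayerNorm and output projection.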
+ + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 65e8b2469e..db357669b6 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -128,6 +128,7 @@ else: "AnimateDiffVideoToVideoControlNetPipeline", ] _import_structure["bria"] = ["BriaPipeline"] + _import_structure["bria_fibo"] = ["BriaFiboPipeline"] _import_structure["flux"] = [ "FluxControlPipeline", "FluxControlInpaintPipeline", @@ -562,6 +563,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .aura_flow import AuraFlowPipeline from .blip_diffusion import BlipDiffusionPipeline from .bria import BriaPipeline + from .bria_fibo import BriaFiboPipeline from .chroma import ChromaImg2ImgPipeline, ChromaPipeline from .cogvideo import ( CogVideoXFunControlPipeline, diff --git a/src/diffusers/pipelines/bria_fibo/__init__.py b/src/diffusers/pipelines/bria_fibo/__init__.py new file mode 100644 index 0000000000..206a463b39 --- /dev/null +++ b/src/diffusers/pipelines/bria_fibo/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_bria_fibo import BriaFiboPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py new file mode 100644 index 0000000000..85d29029e6 --- /dev/null +++ b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py @@ -0,0 +1,838 @@ +# Copyright (c) Bria.ai. All rights reserved. +# +# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0). +# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/ +# +# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit, +# indicate if changes were made, and do not use the material for commercial purposes. +# +# See the license for further details. 
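+# FIBO pairs a SmolLM3 text encoder with the Wan VAE and the BriaFiboTransformer2DModel added
+# above; prompts are structured JSON and are embedded from the text encoder's last two hidden
+# states (see get_prompt_embeds below).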
+ +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import AutoTokenizer +from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM + +from ...image_processor import VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin +from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan +from ...models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel +from ...pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput +from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Example: + ```python + import torch + from diffusers import BriaFiboPipeline + from diffusers.modular_pipelines import ModularPipeline + + torch.set_grad_enabled(False) + vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True) + + pipe = BriaFiboPipeline.from_pretrained( + "briaai/FIBO", + trust_remote_code=True, + torch_dtype=torch.bfloat16, + ) + pipe.enable_model_cpu_offload() + + with torch.inference_mode(): + # 1. Create a prompt to generate an initial image + output = vlm_pipe(prompt="a beautiful dog") + json_prompt_generate = output.values["json_prompt"] + + # Generate the image from the structured json prompt + results_generate = pipe(prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=5) + results_generate.images[0].save("image_generate.png") + ``` +""" + + +class BriaFiboPipeline(DiffusionPipeline): + r""" + Args: + transformer (`BriaFiboTransformer2DModel`): + The transformer model for 2D diffusion modeling. + scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`): + Scheduler to be used with `transformer` to denoise the encoded latents. + vae (`AutoencoderKLWan`): + Variational Auto-Encoder for encoding and decoding images to and from latent representations. + text_encoder (`SmolLM3ForCausalLM`): + Text encoder for processing input prompts. + tokenizer (`AutoTokenizer`): + Tokenizer used for processing the input text prompts for the text_encoder. 
+ """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + transformer: BriaFiboTransformer2DModel, + scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers], + vae: AutoencoderKLWan, + text_encoder: SmolLM3ForCausalLM, + tokenizer: AutoTokenizer, + ): + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor = 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.default_sample_size = 64 + + def get_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + max_sequence_length: int = 2048, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + if not prompt: + raise ValueError("`prompt` must be a non-empty string or list of strings.") + + batch_size = len(prompt) + bot_token_id = 128000 + + text_encoder_device = device if device is not None else torch.device("cpu") + if not isinstance(text_encoder_device, torch.device): + text_encoder_device = torch.device(text_encoder_device) + + if all(p == "" for p in prompt): + input_ids = torch.full((batch_size, 1), bot_token_id, dtype=torch.long, device=text_encoder_device) + attention_mask = torch.ones_like(input_ids) + else: + tokenized = self.tokenizer( + prompt, + padding="longest", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + input_ids = tokenized.input_ids.to(text_encoder_device) + attention_mask = tokenized.attention_mask.to(text_encoder_device) + + if any(p == "" for p in prompt): + empty_rows = torch.tensor([p == "" for p in prompt], dtype=torch.bool, device=text_encoder_device) + input_ids[empty_rows] = bot_token_id + attention_mask[empty_rows] = 1 + + encoder_outputs = self.text_encoder( + input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_outputs.hidden_states + + prompt_embeds = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1) + prompt_embeds = prompt_embeds.to(device=device, dtype=dtype) + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + hidden_states = tuple( + layer.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) for layer in hidden_states + ) + attention_mask = attention_mask.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) + + return prompt_embeds, hidden_states, attention_mask + + @staticmethod + def pad_embedding(prompt_embeds, max_tokens, attention_mask=None): + # Pad embeddings to `max_tokens` while preserving the mask of real tokens. 
+ batch_size, seq_len, dim = prompt_embeds.shape + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), dtype=prompt_embeds.dtype, device=prompt_embeds.device) + else: + attention_mask = attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) + + if max_tokens < seq_len: + raise ValueError("`max_tokens` must be greater or equal to the current sequence length.") + + if max_tokens > seq_len: + pad_length = max_tokens - seq_len + padding = torch.zeros( + (batch_size, pad_length, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + prompt_embeds = torch.cat([prompt_embeds, padding], dim=1) + + mask_padding = torch.zeros( + (batch_size, pad_length), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + attention_mask = torch.cat([attention_mask, mask_padding], dim=1) + + return prompt_embeds, attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + guidance_scale: float = 5, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 3000, + lora_scale: Optional[float] = None, + ): + r""" + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + guidance_scale (`float`): + Guidance scale for classifier free guidance. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + prompt_attention_mask = None + negative_prompt_attention_mask = None + if prompt_embeds is None: + prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype) + prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers] + + if guidance_scale > 1: + if isinstance(negative_prompt, list) and negative_prompt[0] is None: + negative_prompt = "" + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds( + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype) + negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers] + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + # Pad to longest + if prompt_attention_mask is not None: + prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) + + if negative_prompt_embeds is not None: + if negative_prompt_attention_mask is not None: + negative_prompt_attention_mask = negative_prompt_attention_mask.to( + device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype + ) + max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1]) + + prompt_embeds, prompt_attention_mask = self.pad_embedding( + prompt_embeds, max_tokens, attention_mask=prompt_attention_mask + ) + prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers] + + negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding( + negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask + ) + negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers] + else: + max_tokens = prompt_embeds.shape[1] + prompt_embeds, prompt_attention_mask = self.pad_embedding( + prompt_embeds, max_tokens, attention_mask=prompt_attention_mask + ) + negative_prompt_layers = None + + dtype = self.text_encoder.dtype + text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype) + + return ( + prompt_embeds, + negative_prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_layers, + negative_prompt_layers, + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
+ + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @staticmethod + # Based on diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _unpack_latents_no_patch(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels) + latents = latents.permute(0, 3, 1, 2) + + return latents + + @staticmethod + def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width): + latents = latents.permute(0, 2, 3, 1) + latents = latents.reshape(batch_size, height * width, num_channels_latents) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + do_patching=False, + ): + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+            )
+
+        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        if do_patching:
+            latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+        else:
+            latents = self._pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width)
+            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+
+        return latents, latent_image_ids
+
+    @staticmethod
+    def _prepare_attention_mask(attention_mask):
+        attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask)
+
+        # convert to an additive mask: 0 -> keep, -inf -> ignore
+        attention_matrix = torch.where(
+            attention_matrix == 1, 0.0, -torch.inf
+        )  # -inf nulls the softmax score of ignored tokens
+        return attention_matrix
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 30,
+        timesteps: List[int] = None,
+        guidance_scale: float = 5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 3000,
+        do_patching=False,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+            width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            num_inference_steps (`int`, *optional*, defaults to 30):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to 5.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                text `prompt`, usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation.
If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.bria_fibo.pipeline_output.BriaFiboPipelineOutput`] instead of a
+                plain tuple.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int` defaults to 3000): Maximum sequence length to use with the `prompt`.
+            do_patching (`bool`, *optional*, defaults to `False`): Whether to use patching.
+        Examples:
+        Returns:
+            [`~pipelines.bria_fibo.pipeline_output.BriaFiboPipelineOutput`] or `tuple`:
+            [`~pipelines.bria_fibo.pipeline_output.BriaFiboPipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is a list with the generated images.
+        """
+
+        height = height or self.default_sample_size * self.vae_scale_factor
+        width = width or self.default_sample_size * self.vae_scale_factor
+
+        # 1. Check inputs.
Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_layers, + negative_prompt_layers, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + guidance_scale=guidance_scale, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + num_images_per_prompt=num_images_per_prompt, + lora_scale=lora_scale, + ) + prompt_batch_size = prompt_embeds.shape[0] + + if guidance_scale > 1: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_layers = [ + torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers)) + ] + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + total_num_layers_transformer = len(self.transformer.transformer_blocks) + len( + self.transformer.single_transformer_blocks + ) + if len(prompt_layers) >= total_num_layers_transformer: + # remove first layers + prompt_layers = prompt_layers[len(prompt_layers) - total_num_layers_transformer :] + else: + # duplicate last layer + prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers)) + + # 5. 
Prepare latent variables + + num_channels_latents = self.transformer.config.in_channels + if do_patching: + num_channels_latents = int(num_channels_latents / 4) + + latents, latent_image_ids = self.prepare_latents( + prompt_batch_size, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + do_patching, + ) + + latent_attention_mask = torch.ones( + [latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device + ) + if guidance_scale > 1: + latent_attention_mask = latent_attention_mask.repeat(2, 1) + + attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1) + attention_mask = self._prepare_attention_mask(attention_mask) # batch, seq => batch, seq, seq + attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting + + if self._joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + self._joint_attention_kwargs["attention_mask"] = attention_mask + + # Adapt scheduler to dynamic shifting (resolution dependent) + + if do_patching: + seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2)) + else: + seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor) + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + + mu = calculate_shift( + seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + + # Init sigmas and timesteps according to shift size + # This changes the scheduler in-place according to the dynamic scheduling + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps=num_inference_steps, + device=device, + timesteps=None, + sigmas=sigmas, + mu=mu, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # Support old different diffusers versions + if len(latent_image_ids.shape) == 3: + latent_image_ids = latent_image_ids[0] + + if len(text_ids.shape) == 3: + text_ids = text_ids[0] + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to( + device=latent_model_input.device, dtype=latent_model_input.dtype + ) + + # This is predicts "v" from flow-matching or eps from diffusion + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + text_encoder_layers=prompt_layers, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + txt_ids=text_ids, + img_ids=latent_image_ids, + )[0] + + # perform guidance + if guidance_scale > 1: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + if do_patching: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + else: + latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor) + + latents = latents.unsqueeze(dim=2) + latents_device = latents[0].device + latents_dtype = latents[0].dtype + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents_device, latents_dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents_device, latents_dtype + ) + latents_scaled = [latent / latents_std + latents_mean for latent in latents] + latents_scaled = torch.cat(latents_scaled, dim=0) + image = [] + for scaled_latent in latents_scaled: + curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0] + curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type) + image.append(curr_image) + if len(image) == 1: + image = image[0] + else: + image = np.stack(image, axis=0) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return BriaFiboPipelineOutput(images=image) + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if max_sequence_length is not None and max_sequence_length > 3000: + raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}") diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_output.py b/src/diffusers/pipelines/bria_fibo/pipeline_output.py new file mode 100644 index 0000000000..f459185a2c --- /dev/null +++ b/src/diffusers/pipelines/bria_fibo/pipeline_output.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class BriaFiboPipelineOutput(BaseOutput): + """ + Output class for BriaFibo pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 6ef6b8b0e9..3c426d5039 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -588,6 +588,21 @@ class AutoModel(metaclass=DummyObject): requires_backends(cls, ["torch"]) +class BriaFiboTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class BriaTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 7c1dcba9c7..20575ff229 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -482,6 +482,21 @@ class AuraFlowPipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) +class BriaFiboPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class BriaPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git 
a/tests/models/transformers/test_models_transformer_bria_fibo.py b/tests/models/transformers/test_models_transformer_bria_fibo.py new file mode 100644 index 0000000000..f859f4608b --- /dev/null +++ b/tests/models/transformers/test_models_transformer_bria_fibo.py @@ -0,0 +1,89 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import BriaFiboTransformer2DModel + +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class BriaFiboTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = BriaFiboTransformer2DModel + main_input_name = "hidden_states" + # We override the items here because the transformer under consideration is small. + model_split_percents = [0.8, 0.7, 0.7] + + # Skip setting testing with default: AttnProcessor + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 1 + num_latent_channels = 48 + num_image_channels = 3 + height = width = 16 + sequence_length = 32 + embedding_dim = 64 + + hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device) + image_ids = torch.randn((height * width, num_image_channels)).to(torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "img_ids": image_ids, + "txt_ids": text_ids, + "timestep": timestep, + "text_encoder_layers": [encoder_hidden_states[:, :, :32], encoder_hidden_states[:, :, :32]], + } + + @property + def input_shape(self): + return (16, 16) + + @property + def output_shape(self): + return (256, 48) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "patch_size": 1, + "in_channels": 48, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 8, + "num_attention_heads": 2, + "joint_attention_dim": 64, + "text_encoder_dim": 32, + "pooled_projection_dim": None, + "axes_dims_rope": [0, 4, 4], + } + + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"BriaFiboTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/pipelines/bria_fibo/__init__.py b/tests/pipelines/bria_fibo/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py b/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py new file mode 100644 index 0000000000..76b41114f8 --- /dev/null +++ b/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py @@ -0,0 +1,139 @@ +# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer +from transformers.models.smollm3.modeling_smollm3 import SmolLM3Config, SmolLM3ForCausalLM + +from diffusers import ( + AutoencoderKLWan, + BriaFiboPipeline, + FlowMatchEulerDiscreteScheduler, +) +from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel +from tests.pipelines.test_pipelines_common import PipelineTesterMixin + +from ...testing_utils import ( + enable_full_determinism, + torch_device, +) + + +enable_full_determinism() + + +class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = BriaFiboPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale"]) + batch_params = frozenset(["prompt"]) + test_xformers_attention = False + test_layerwise_casting = False + test_group_offloading = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = BriaFiboTransformer2DModel( + patch_size=1, + in_channels=16, + num_layers=1, + num_single_layers=1, + attention_head_dim=8, + num_attention_heads=2, + joint_attention_dim=64, + text_encoder_dim=32, + pooled_projection_dim=None, + axes_dims_rope=[0, 4, 4], + ) + + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=160, + decoder_base_dim=256, + num_res_blocks=2, + out_channels=12, + patch_size=2, + scale_factor_spatial=16, + scale_factor_temporal=4, + temperal_downsample=[False, True, True], + z_dim=16, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + torch.manual_seed(0) + text_encoder = SmolLM3ForCausalLM(SmolLM3Config(hidden_size=32)) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "transformer": transformer, + "vae": vae, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "prompt": "{'text': 'A painting of a squirrel eating a burger'}", + "negative_prompt": "bad, ugly", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 32, + "width": 32, + "output_type": "np", + } + return inputs + + @unittest.skip(reason="will not be supported due to dim-fusion") + def test_encode_prompt_works_in_isolation(self): + pass + + def test_bria_fibo_different_prompts(self): + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt"] = "a different prompt" + output_different_prompts = pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - output_different_prompts).max() + assert max_diff > 1e-6 + + def 
test_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (64, 64), (32, 64)] + for height, width in height_width_pairs: + expected_height = height + expected_width = width + + inputs.update({"height": height, "width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width) From 9f3c0fdcd859905c2c13ec47f10eb0250d2576ac Mon Sep 17 00:00:00 2001 From: Pavle Padjin Date: Thu, 30 Oct 2025 04:09:40 +0100 Subject: [PATCH 13/16] Avoiding graph break by changing the way we infer dtype in vae.decoder (#12512) * Changing the way we infer dtype to avoid force evaluation of lazy tensors * changing way to infer dtype to ensure type consistency * more robust infering of dtype * removing the upscale dtype entirely --- src/diffusers/models/autoencoders/vae.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/models/autoencoders/vae.py b/src/diffusers/models/autoencoders/vae.py index 7b17196125..9c6031a988 100644 --- a/src/diffusers/models/autoencoders/vae.py +++ b/src/diffusers/models/autoencoders/vae.py @@ -286,11 +286,9 @@ class Decoder(nn.Module): sample = self.conv_in(sample) - upscale_dtype = next(iter(self.up_blocks.parameters())).dtype if torch.is_grad_enabled() and self.gradient_checkpointing: # middle sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds) - sample = sample.to(upscale_dtype) # up for up_block in self.up_blocks: @@ -298,7 +296,6 @@ class Decoder(nn.Module): else: # middle sample = self.mid_block(sample, latent_embeds) - sample = sample.to(upscale_dtype) # up for up_block in self.up_blocks: From df8dd778177c3d2272f74cbebd880d7abd9f5ec9 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 31 Oct 2025 00:14:24 +0530 Subject: [PATCH 14/16] [Modular] Fix for custom block kwargs (#12561) update --- src/diffusers/modular_pipelines/modular_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index bf067555a8..55c261ab2f 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -335,7 +335,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): ) expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls) block_kwargs = { - name: kwargs.pop(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs + name: kwargs.get(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs } return block_cls(**block_kwargs) From d54622c2679d700b425ad61abce9b80fc36212c0 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 31 Oct 2025 13:47:02 +0530 Subject: [PATCH 15/16] [Modular] Allow custom blocks to be saved to `local_dir` (#12381) update Co-authored-by: YiYi Xu --- src/diffusers/modular_pipelines/modular_pipeline.py | 2 +- src/diffusers/utils/dynamic_modules_utils.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 55c261ab2f..ef1673c057 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -305,6 +305,7 @@ class 
ModularPipelineBlocks(ConfigMixin, PushToHubMixin): "cache_dir", "force_download", "local_files_only", + "local_dir", "proxies", "resume_download", "revision", @@ -331,7 +332,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): module_file=module_file, class_name=class_name, **hub_kwargs, - **kwargs, ) expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls) block_kwargs = { diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py index 627b1e0604..b2ef5a29e0 100644 --- a/src/diffusers/utils/dynamic_modules_utils.py +++ b/src/diffusers/utils/dynamic_modules_utils.py @@ -254,6 +254,7 @@ def get_cached_module_file( token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, local_files_only: bool = False, + local_dir: Optional[str] = None, ): """ Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached @@ -332,6 +333,7 @@ def get_cached_module_file( force_download=force_download, proxies=proxies, local_files_only=local_files_only, + local_dir=local_dir, ) submodule = "git" module_file = pretrained_model_name_or_path + ".py" @@ -355,6 +357,7 @@ def get_cached_module_file( force_download=force_download, proxies=proxies, local_files_only=local_files_only, + local_dir=local_dir, token=token, ) submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/"))) @@ -415,6 +418,7 @@ def get_cached_module_file( token=token, revision=revision, local_files_only=local_files_only, + local_dir=local_dir, ) return os.path.join(full_submodule, module_file) @@ -431,7 +435,7 @@ def get_class_from_dynamic_module( token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, local_files_only: bool = False, - **kwargs, + local_dir: Optional[str] = None, ): """ Extracts a class from a module file, present in the local folder or repository of a model. 
@@ -496,5 +500,6 @@ def get_class_from_dynamic_module( token=token, revision=revision, local_files_only=local_files_only, + local_dir=local_dir, ) return get_class_in_module(class_name, final_module) From 051c8a1c0f5c393a447bef18081fdf94c2a3ab9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Friedrich=20Sch=C3=B6ller?= Date: Fri, 31 Oct 2025 21:25:13 +0100 Subject: [PATCH 16/16] Fix Stable Diffusion 3.x pooled prompt embedding with multiple images (#12306) --- .../controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py | 2 +- .../pipeline_stable_diffusion_3_controlnet_inpainting.py | 2 +- src/diffusers/pipelines/pag/pipeline_pag_sd_3.py | 2 +- src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py | 2 +- .../pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py | 2 +- .../stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py | 2 +- .../stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py index f67a0e2112..d605eac1f2 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py @@ -355,7 +355,7 @@ class StableDiffusion3ControlNetPipeline( prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds, pooled_prompt_embeds diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py index 68984da4dc..9d0158c6b6 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py @@ -373,7 +373,7 @@ class StableDiffusion3ControlNetInpaintingPipeline( prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds, pooled_prompt_embeds diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py index bc281428e2..941b675099 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py @@ -326,7 +326,7 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * 
num_images_per_prompt, -1) return prompt_embeds, pooled_prompt_embeds diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py index 22a8dac238..f40dd52fc2 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py @@ -342,7 +342,7 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds, pooled_prompt_embeds diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 3b7b26dc63..660d9801df 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -336,7 +336,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds, pooled_prompt_embeds diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py index db047f1992..9b11bc8781 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py @@ -361,7 +361,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds, pooled_prompt_embeds diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py index c95fa530c8..b947cbff09 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py @@ -367,7 +367,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) 
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds, pooled_prompt_embeds
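
A note on the `repeat` change in the hunks above: `pooled_prompt_embeds` is a 2-D `(batch_size, embed_dim)` tensor, so the old three-argument `repeat(1, num_images_per_prompt, 1)` tiled the whole batch and, after the subsequent `view`, interleaved prompts instead of keeping each prompt's copies contiguous the way the sequence embeddings are duplicated. The following standalone sketch (hypothetical sizes, not diffusers code) illustrates the difference:

```python
import torch

# Hypothetical sizes: 2 prompts, 2 images per prompt, embedding dim 4.
batch_size, num_images_per_prompt, dim = 2, 2, 4
pooled = torch.arange(batch_size * dim, dtype=torch.float32).reshape(batch_size, dim)

# Old behaviour: a 3-argument repeat on a 2-D tensor tiles the whole batch,
# so after the view the rows interleave prompts: [p0, p1, p0, p1].
old = pooled.repeat(1, num_images_per_prompt, 1).view(batch_size * num_images_per_prompt, -1)

# Fixed behaviour: the 2-argument repeat keeps each prompt's copies contiguous,
# matching how the sequence embeddings are repeated: [p0, p0, p1, p1].
new = pooled.repeat(1, num_images_per_prompt).view(batch_size * num_images_per_prompt, -1)

print(old[:, 0])  # tensor([0., 4., 0., 4.])  -> prompts interleaved
print(new[:, 0])  # tensor([0., 0., 4., 4.])  -> per-prompt copies grouped together
```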