From b5309683cb6753e2111be9a8204f90a550c3fcb6 Mon Sep 17 00:00:00 2001 From: Miguel Martin Date: Thu, 18 Dec 2025 16:08:18 -0800 Subject: [PATCH 1/8] Cosmos Predict2.5 Base: inference pipeline, scheduler & chkpt conversion (#12852) * cosmos predict2.5 base: convert chkpt & pipeline - New scheduler: scheduling_flow_unipc_multistep.py - Changes to TransformerCosmos for text embeddings via crossattn_proj * scheduler cleanup * simplify inference pipeline * cleanup scheduler + tests * Basic tests for flow unipc * working b2b inference * Rename everything * Tests for pipeline present, but not working (predict2 also not working) * docstring update * wrapper pipelines + make style * remove unnecessary files * UniPCMultistep: support use_karras_sigmas=True and use_flow_sigmas=True * use UniPCMultistepScheduler + fix tests for pipeline * Remove FlowUniPCMultistepScheduler * UniPCMultistepScheduler for use_flow_sigmas=True & use_karras_sigmas=True * num_inference_steps=36 due to bug in scheduler used by predict2.5 * Address comments * make style + make fix-copies * fix tests + remove references to old pipelines * address comments * add revision in from_pretrained call * fix tests --- docs/source/en/api/pipelines/cosmos.md | 6 + scripts/convert_cosmos_to_diffusers.py | 135 ++- src/diffusers/__init__.py | 2 + .../models/transformers/transformer_cosmos.py | 13 + src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/cosmos/__init__.py | 6 + .../cosmos/pipeline_cosmos2_5_predict.py | 847 ++++++++++++++++++ .../schedulers/scheduling_unipc_multistep.py | 9 +- .../dummy_torch_and_transformers_objects.py | 15 + tests/pipelines/cosmos/cosmos_guardrail.py | 11 +- .../cosmos/test_cosmos2_5_predict.py | 337 +++++++ tests/schedulers/test_scheduler_unipc.py | 29 + 12 files changed, 1398 insertions(+), 14 deletions(-) create mode 100644 src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py create mode 100644 tests/pipelines/cosmos/test_cosmos2_5_predict.py diff --git a/docs/source/en/api/pipelines/cosmos.md b/docs/source/en/api/pipelines/cosmos.md index fb9453480e..60ecce6603 100644 --- a/docs/source/en/api/pipelines/cosmos.md +++ b/docs/source/en/api/pipelines/cosmos.md @@ -70,6 +70,12 @@ output.save("output.png") - all - __call__ +## Cosmos2_5_PredictBasePipeline + +[[autodoc]] Cosmos2_5_PredictBasePipeline + - all + - __call__ + ## CosmosPipelineOutput [[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput diff --git a/scripts/convert_cosmos_to_diffusers.py b/scripts/convert_cosmos_to_diffusers.py index 6f6563ad64..6e70f8cc05 100644 --- a/scripts/convert_cosmos_to_diffusers.py +++ b/scripts/convert_cosmos_to_diffusers.py @@ -1,11 +1,55 @@ +""" +# Cosmos 2 Predict + +Download checkpoint +```bash +hf download nvidia/Cosmos-Predict2-2B-Text2Image +``` + +convert checkpoint +```bash +transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2-2B-Text2Image/snapshots/acdb5fde992a73ef0355f287977d002cbfd127e0/model.pt + +python scripts/convert_cosmos_to_diffusers.py \ + --transformer_ckpt_path $transformer_ckpt_path \ + --transformer_type Cosmos-2.0-Diffusion-2B-Text2Image \ + --text_encoder_path google-t5/t5-11b \ + --tokenizer_path google-t5/t5-11b \ + --vae_type wan2.1 \ + --output_path converted/cosmos-p2-t2i-2b \ + --save_pipeline +``` + +# Cosmos 2.5 Predict + +Download checkpoint +```bash +hf download nvidia/Cosmos-Predict2.5-2B +``` + +Convert checkpoint +```bash +transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt + +python scripts/convert_cosmos_to_diffusers.py \ + --transformer_type Cosmos-2.5-Predict-Base-2B \ + --transformer_ckpt_path $transformer_ckpt_path \ + --vae_type wan2.1 \ + --output_path converted/cosmos-p2.5-base-2b \ + --save_pipeline +``` + +""" + import argparse import pathlib +import sys from typing import Any, Dict import torch from accelerate import init_empty_weights from huggingface_hub import snapshot_download -from transformers import T5EncoderModel, T5TokenizerFast +from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration, T5EncoderModel, T5TokenizerFast from diffusers import ( AutoencoderKLCosmos, @@ -17,7 +61,9 @@ from diffusers import ( CosmosVideoToWorldPipeline, EDMEulerScheduler, FlowMatchEulerDiscreteScheduler, + UniPCMultistepScheduler, ) +from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBasePipeline def remove_keys_(key: str, state_dict: Dict[str, Any]): @@ -233,6 +279,25 @@ TRANSFORMER_CONFIGS = { "concat_padding_mask": True, "extra_pos_embed_type": None, }, + "Cosmos-2.5-Predict-Base-2B": { + "in_channels": 16 + 1, + "out_channels": 16, + "num_attention_heads": 16, + "attention_head_dim": 128, + "num_layers": 28, + "mlp_ratio": 4.0, + "text_embed_dim": 1024, + "adaln_lora_dim": 256, + "max_size": (128, 240, 240), + "patch_size": (1, 2, 2), + "rope_scale": (1.0, 3.0, 3.0), + "concat_padding_mask": True, + # NOTE: source config has pos_emb_learnable: 'True' - but params are missing + "extra_pos_embed_type": None, + "use_crossattn_projection": True, + "crossattn_proj_in_channels": 100352, + "encoder_hidden_states_channels": 1024, + }, } VAE_KEYS_RENAME_DICT = { @@ -334,6 +399,9 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo elif "Cosmos-2.0" in transformer_type: TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 + elif "Cosmos-2.5" in transformer_type: + TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 + TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 else: assert False @@ -347,6 +415,7 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo new_key = new_key.removeprefix(PREFIX_KEY) for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): new_key = new_key.replace(replace_key, rename_key) + print(key, "->", new_key, flush=True) update_state_dict_(original_state_dict, key, new_key) for key in list(original_state_dict.keys()): @@ -355,6 +424,21 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo continue handler_fn_inplace(key, original_state_dict) + expected_keys = set(transformer.state_dict().keys()) + mapped_keys = set(original_state_dict.keys()) + missing_keys = expected_keys - mapped_keys + unexpected_keys = mapped_keys - expected_keys + if missing_keys: + print(f"ERROR: missing keys ({len(missing_keys)} from state_dict:", flush=True, file=sys.stderr) + for k in missing_keys: + print(k) + sys.exit(1) + if unexpected_keys: + print(f"ERROR: unexpected keys ({len(unexpected_keys)}) from state_dict:", flush=True, file=sys.stderr) + for k in unexpected_keys: + print(k) + sys.exit(2) + transformer.load_state_dict(original_state_dict, strict=True, assign=True) return transformer @@ -444,6 +528,34 @@ def save_pipeline_cosmos_2_0(args, transformer, vae): pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") +def save_pipeline_cosmos2_5(args, transformer, vae): + text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B" + tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct" + + text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained( + text_encoder_path, torch_dtype="auto", device_map="cpu" + ) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + + scheduler = UniPCMultistepScheduler( + use_karras_sigmas=True, + use_flow_sigmas=True, + prediction_type="flow_prediction", + sigma_max=200.0, + sigma_min=0.01, + ) + + pipe = Cosmos2_5_PredictBasePipeline( + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + vae=vae, + scheduler=scheduler, + safety_checker=lambda *args, **kwargs: None, + ) + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + + def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys())) @@ -451,10 +563,10 @@ def get_args(): "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint" ) parser.add_argument( - "--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE" + "--vae_type", type=str, default="wan2.1", choices=["wan2.1", *list(VAE_CONFIGS.keys())], help="Type of VAE" ) - parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b") - parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b") + parser.add_argument("--text_encoder_path", type=str, default=None) + parser.add_argument("--tokenizer_path", type=str, default=None) parser.add_argument("--save_pipeline", action="store_true") parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.") @@ -477,8 +589,6 @@ if __name__ == "__main__": if args.save_pipeline: assert args.transformer_ckpt_path is not None assert args.vae_type is not None - assert args.text_encoder_path is not None - assert args.tokenizer_path is not None if args.transformer_ckpt_path is not None: weights_only = "Cosmos-1.0" in args.transformer_type @@ -490,17 +600,26 @@ if __name__ == "__main__": if args.vae_type is not None: if "Cosmos-1.0" in args.transformer_type: vae = convert_vae(args.vae_type) - else: + elif "Cosmos-2.0" in args.transformer_type or "Cosmos-2.5" in args.transformer_type: vae = AutoencoderKLWan.from_pretrained( "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32 ) + else: + raise AssertionError(f"{args.transformer_type} not supported") + if not args.save_pipeline: vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") if args.save_pipeline: if "Cosmos-1.0" in args.transformer_type: + assert args.text_encoder_path is not None + assert args.tokenizer_path is not None save_pipeline_cosmos_1_0(args, transformer, vae) elif "Cosmos-2.0" in args.transformer_type: + assert args.text_encoder_path is not None + assert args.tokenizer_path is not None save_pipeline_cosmos_2_0(args, transformer, vae) + elif "Cosmos-2.5" in args.transformer_type: + save_pipeline_cosmos2_5(args, transformer, vae) else: - assert False + raise AssertionError(f"{args.transformer_type} not supported") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 03ecaf6bc1..6aac3feffd 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -463,6 +463,7 @@ else: "CogView4ControlPipeline", "CogView4Pipeline", "ConsisIDPipeline", + "Cosmos2_5_PredictBasePipeline", "Cosmos2TextToImagePipeline", "Cosmos2VideoToWorldPipeline", "CosmosTextToWorldPipeline", @@ -1175,6 +1176,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: CogView4ControlPipeline, CogView4Pipeline, ConsisIDPipeline, + Cosmos2_5_PredictBasePipeline, Cosmos2TextToImagePipeline, Cosmos2VideoToWorldPipeline, CosmosTextToWorldPipeline, diff --git a/src/diffusers/models/transformers/transformer_cosmos.py b/src/diffusers/models/transformers/transformer_cosmos.py index 373b470ae3..2b0c266707 100644 --- a/src/diffusers/models/transformers/transformer_cosmos.py +++ b/src/diffusers/models/transformers/transformer_cosmos.py @@ -439,6 +439,9 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0), concat_padding_mask: bool = True, extra_pos_embed_type: Optional[str] = "learnable", + use_crossattn_projection: bool = False, + crossattn_proj_in_channels: int = 1024, + encoder_hidden_states_channels: int = 1024, ) -> None: super().__init__() hidden_size = num_attention_heads * attention_head_dim @@ -485,6 +488,12 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=False ) + if self.config.use_crossattn_projection: + self.crossattn_proj = nn.Sequential( + nn.Linear(crossattn_proj_in_channels, encoder_hidden_states_channels, bias=True), + nn.GELU(), + ) + self.gradient_checkpointing = False def forward( @@ -524,6 +533,7 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): post_patch_num_frames = num_frames // p_t post_patch_height = height // p_h post_patch_width = width // p_w + hidden_states = self.patch_embed(hidden_states) hidden_states = hidden_states.flatten(1, 3) # [B, T, H, W, C] -> [B, THW, C] @@ -546,6 +556,9 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): else: assert False + if self.config.use_crossattn_projection: + encoder_hidden_states = self.crossattn_proj(encoder_hidden_states) + # 5. Transformer blocks for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 04ec6b5cd8..e8faf868e7 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -165,6 +165,7 @@ else: _import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"] _import_structure["consisid"] = ["ConsisIDPipeline"] _import_structure["cosmos"] = [ + "Cosmos2_5_PredictBasePipeline", "Cosmos2TextToImagePipeline", "CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline", @@ -622,6 +623,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: StableDiffusionXLControlNetXSPipeline, ) from .cosmos import ( + Cosmos2_5_PredictBasePipeline, Cosmos2TextToImagePipeline, Cosmos2VideoToWorldPipeline, CosmosTextToWorldPipeline, diff --git a/src/diffusers/pipelines/cosmos/__init__.py b/src/diffusers/pipelines/cosmos/__init__.py index 2833c89abd..944f165531 100644 --- a/src/diffusers/pipelines/cosmos/__init__.py +++ b/src/diffusers/pipelines/cosmos/__init__.py @@ -22,6 +22,9 @@ except OptionalDependencyNotAvailable: _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: + _import_structure["pipeline_cosmos2_5_predict"] = [ + "Cosmos2_5_PredictBasePipeline", + ] _import_structure["pipeline_cosmos2_text2image"] = ["Cosmos2TextToImagePipeline"] _import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"] _import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"] @@ -35,6 +38,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: + from .pipeline_cosmos2_5_predict import ( + Cosmos2_5_PredictBasePipeline, + ) from .pipeline_cosmos2_text2image import Cosmos2TextToImagePipeline from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py new file mode 100644 index 0000000000..6564b59373 --- /dev/null +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py @@ -0,0 +1,847 @@ +# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import torch +import torchvision +import torchvision.transforms +import torchvision.transforms.functional +from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLWan, CosmosTransformer3DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import CosmosPipelineOutput + + +if is_cosmos_guardrail_available(): + from cosmos_guardrail import CosmosSafetyChecker +else: + + class CosmosSafetyChecker: + def __init__(self, *args, **kwargs): + raise ImportError( + "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`." + ) + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import Cosmos2_5_PredictBasePipeline + >>> from diffusers.utils import export_to_video, load_image, load_video + + >>> model_id = "nvidia/Cosmos-Predict2.5-2B" + >>> pipe = Cosmos2_5_PredictBasePipeline.from_pretrained( + ... model_id, revision="diffusers/base/pre-trianed", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> # Common negative prompt reused across modes. + >>> negative_prompt = ( + ... "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + ... "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + ... "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky " + ... "movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " + ... "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " + ... "Overall, the video is of poor quality." + ... ) + + >>> # Text2World: generate a 93-frame world video from text only. + >>> prompt = ( + ... "As the red light shifts to green, the red bus at the intersection begins to move forward, its headlights " + ... "cutting through the falling snow. The snowy tire tracks deepen as the vehicle inches ahead, casting fresh " + ... "lines onto the slushy road. Around it, streetlights glow warmer, illuminating the drifting flakes and wet " + ... "reflections on the asphalt. Other cars behind start to edge forward, their beams joining the scene. " + ... "The stillness of the urban street transitions into motion as the quiet snowfall is punctuated by the slow " + ... "advance of traffic through the frosty city corridor." + ... ) + >>> video = pipe( + ... image=None, + ... video=None, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_frames=93, + ... generator=torch.Generator().manual_seed(1), + ... ).frames[0] + >>> export_to_video(video, "text2world.mp4", fps=16) + + >>> # Image2World: condition on a single image and generate a 93-frame world video. + >>> prompt = ( + ... "A high-definition video captures the precision of robotic welding in an industrial setting. " + ... "The first frame showcases a robotic arm, equipped with a welding torch, positioned over a large metal structure. " + ... "The welding process is in full swing, with bright sparks and intense light illuminating the scene, creating a vivid " + ... "display of blue and white hues. A significant amount of smoke billows around the welding area, partially obscuring " + ... "the view but emphasizing the heat and activity. The background reveals parts of the workshop environment, including a " + ... "ventilation system and various pieces of machinery, indicating a busy and functional industrial workspace. As the video " + ... "progresses, the robotic arm maintains its steady position, continuing the welding process and moving to its left. " + ... "The welding torch consistently emits sparks and light, and the smoke continues to rise, diffusing slightly as it moves upward. " + ... "The metal surface beneath the torch shows ongoing signs of heating and melting. The scene retains its industrial ambiance, with " + ... "the welding sparks and smoke dominating the visual field, underscoring the ongoing nature of the welding operation." + ... ) + >>> image = load_image( + ... "https://media.githubusercontent.com/media/nvidia-cosmos/cosmos-predict2.5/refs/heads/main/assets/base/robot_welding.jpg" + ... ) + >>> video = pipe( + ... image=image, + ... video=None, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_frames=93, + ... generator=torch.Generator().manual_seed(1), + ... ).frames[0] + >>> # export_to_video(video, "image2world.mp4", fps=16) + + >>> # Video2World: condition on an input clip and predict a 93-frame world video. + >>> prompt = ( + ... "The video opens with an aerial view of a large-scale sand mining construction operation, showcasing extensive piles " + ... "of brown sand meticulously arranged in parallel rows. A central water channel, fed by a water pipe, flows through the " + ... "middle of these sand heaps, creating ripples and movement as it cascades down. The surrounding area features dense green " + ... "vegetation on the left, contrasting with the sandy terrain, while a body of water is visible in the background on the right. " + ... "As the video progresses, a piece of heavy machinery, likely a bulldozer, enters the frame from the right, moving slowly along " + ... "the edge of the sand piles. This machinery's presence indicates ongoing construction work in the operation. The final frame " + ... "captures the same scene, with the water continuing its flow and the bulldozer still in motion, maintaining the dynamic yet " + ... "steady pace of the construction activity." + ... ) + >>> input_video = load_video( + ... "https://github.com/nvidia-cosmos/cosmos-predict2.5/raw/refs/heads/main/assets/base/sand_mining.mp4" + ... ) + >>> video = pipe( + ... image=None, + ... video=input_video, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_frames=93, + ... generator=torch.Generator().manual_seed(1), + ... ).frames[0] + >>> export_to_video(video, "video2world.mp4", fps=16) + + >>> # To produce an image instead of a world (video) clip, set num_frames=1 and + >>> # save the first frame: pipe(..., num_frames=1).frames[0][0]. + ``` +""" + + +class Cosmos2_5_PredictBasePipeline(DiffusionPipeline): + r""" + Pipeline for [Cosmos Predict2.5](https://github.com/nvidia-cosmos/cosmos-predict2.5) base model. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`Qwen2_5_VLForConditionalGeneration`]): + Frozen text-encoder. Cosmos Predict2.5 uses the [Qwen2.5 + VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) encoder. + tokenizer (`AutoTokenizer`): + Tokenizer associated with the Qwen2.5 VL encoder. + transformer ([`CosmosTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + # We mark safety_checker as optional here to get around some test failures, but it is not really optional + _optional_components = ["safety_checker"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: AutoTokenizer, + transformer: CosmosTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: UniPCMultistepScheduler, + safety_checker: CosmosSafetyChecker = None, + ): + super().__init__() + + if safety_checker is None: + safety_checker = CosmosSafetyChecker() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).float() + if getattr(self.vae.config, "latents_mean", None) is not None + else None + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).float() + if getattr(self.vae.config, "latents_std", None) is not None + else None + ) + self.latents_mean = latents_mean + self.latents_std = latents_std + + if self.latents_mean is None or self.latents_std is None: + raise ValueError("VAE configuration must define both `latents_mean` and `latents_std`.") + + def _get_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + + input_ids_batch = [] + + for sample_idx in range(len(prompt)): + conversations = [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "You are a helpful assistant who will provide prompts to an image generator.", + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt[sample_idx], + } + ], + }, + ] + input_ids = self.tokenizer.apply_chat_template( + conversations, + tokenize=True, + add_generation_prompt=False, + add_vision_id=False, + max_length=max_sequence_length, + truncation=True, + padding="max_length", + ) + input_ids = torch.LongTensor(input_ids) + input_ids_batch.append(input_ids) + + input_ids_batch = torch.stack(input_ids_batch, dim=0) + + outputs = self.text_encoder( + input_ids_batch.to(device), + output_hidden_states=True, + ) + hidden_states = outputs.hidden_states + + normalized_hidden_states = [] + for layer_idx in range(1, len(hidden_states)): + normalized_state = (hidden_states[layer_idx] - hidden_states[layer_idx].mean(dim=-1, keepdim=True)) / ( + hidden_states[layer_idx].std(dim=-1, keepdim=True) + 1e-8 + ) + normalized_hidden_states.append(normalized_state) + + prompt_embeds = torch.cat(normalized_hidden_states, dim=-1) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds + + # Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_prompt_embeds( + prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_prompt_embeds( + prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + # Modified from diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2VideoToWorldPipeline.prepare_latents and + # diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2TextToImagePipeline.prepare_latents + def prepare_latents( + self, + video: Optional[torch.Tensor], + batch_size: int, + num_channels_latents: int = 16, + height: int = 704, + width: int = 1280, + num_frames_in: int = 93, + num_frames_out: int = 93, + do_classifier_free_guidance: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + B = batch_size + C = num_channels_latents + T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1 + H = height // self.vae_scale_factor_spatial + W = width // self.vae_scale_factor_spatial + shape = (B, C, T, H, W) + + if num_frames_in == 0: + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device) + cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device) + + cond_latents = torch.zeros_like(latents) + + return ( + latents, + cond_latents, + cond_mask, + cond_indicator, + ) + else: + if video is None: + raise ValueError("`video` must be provided when `num_frames_in` is greater than 0.") + needs_preprocessing = not (isinstance(video, torch.Tensor) and video.ndim == 5 and video.shape[1] == 3) + if needs_preprocessing: + video = self.video_processor.preprocess_video(video, height, width) + video = video.to(device=device, dtype=self.vae.dtype) + if isinstance(generator, list): + cond_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i]) + for i in range(batch_size) + ] + else: + cond_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + cond_latents = torch.cat(cond_latents, dim=0).to(dtype) + + latents_mean = self.latents_mean.to(device=device, dtype=dtype) + latents_std = self.latents_std.to(device=device, dtype=dtype) + cond_latents = (cond_latents - latents_mean) / latents_std + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + padding_shape = (B, 1, T, H, W) + ones_padding = latents.new_ones(padding_shape) + zeros_padding = latents.new_zeros(padding_shape) + + num_cond_latent_frames = (num_frames_in - 1) // self.vae_scale_factor_temporal + 1 + cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1) + cond_indicator[:, :, 0:num_cond_latent_frames] = 1.0 + cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding + + return ( + latents, + cond_latents, + cond_mask, + cond_indicator, + ) + + # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput | None = None, + video: List[PipelineImageInput] | None = None, + prompt: Union[str, List[str]] | None = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 704, + width: int = 1280, + num_frames: int = 93, + num_inference_steps: int = 36, + guidance_scale: float = 7.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + conditional_frame_timestep: float = 0.1, + ): + r""" + The call function to the pipeline for generation. Supports three modes: + + - **Text2World**: `image=None`, `video=None`, `prompt` provided. Generates a world clip. + - **Image2World**: `image` provided, `video=None`, `prompt` provided. Conditions on a single frame. + - **Video2World**: `video` provided, `image=None`, `prompt` provided. Conditions on an input clip. + + Set `num_frames=93` (default) to produce a world video, or `num_frames=1` to produce a single image frame (the + above in "*2Image mode"). + + Outputs follow `output_type` (e.g., `"pil"` returns a list of `num_frames` PIL images per prompt). + + Args: + image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*): + Optional single image for Image2World conditioning. Must be `None` when `video` is provided. + video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*): + Optional input video for Video2World conditioning. Must be `None` when `image` is provided. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied. + height (`int`, defaults to `704`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `93`): + Number of output frames. Use `93` for world (video) generation; set to `1` to return a single frame. + num_inference_steps (`int`, defaults to `35`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `7.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If + the prompt is shorter than this length, it will be padded. + + Examples: + + Returns: + [`~CosmosPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + if self.safety_checker is None: + raise ValueError( + f"You have disabled the safety checker for {self.__class__}. This is in violation of the " + "[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). " + f"Please ensure that you are compliant with the license agreement." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs) + + self._guidance_scale = guidance_scale + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + if self.safety_checker is not None: + self.safety_checker.to(device) + if prompt is not None: + prompt_list = [prompt] if isinstance(prompt, str) else prompt + for p in prompt_list: + if not self.safety_checker.check_text_safety(p): + raise ValueError( + f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the " + f"prompt abides by the NVIDIA Open Model License Agreement." + ) + + # Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + vae_dtype = self.vae.dtype + transformer_dtype = self.transformer.dtype + + num_frames_in = None + if image is not None: + if batch_size != 1: + raise ValueError(f"batch_size must be 1 for image input (given {batch_size})") + + image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0) + video = torch.cat([image, torch.zeros_like(image).repeat(num_frames - 1, 1, 1, 1)], dim=0) + video = video.unsqueeze(0) + num_frames_in = 1 + elif video is None: + video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8) + num_frames_in = 0 + else: + num_frames_in = len(video) + + if batch_size != 1: + raise ValueError(f"batch_size must be 1 for video input (given {batch_size})") + + assert video is not None + video = self.video_processor.preprocess_video(video, height, width) + + # pad with last frame (for video2world) + num_frames_out = num_frames + if video.shape[2] < num_frames_out: + n_pad_frames = num_frames_out - num_frames_in + last_frame = video[0, :, -1:, :, :] # [C, T==1, H, W] + pad_frames = last_frame.repeat(1, 1, n_pad_frames, 1, 1) # [B, C, T, H, W] + video = torch.cat((video, pad_frames), dim=2) + + assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})" + + video = video.to(device=device, dtype=vae_dtype) + + num_channels_latents = self.transformer.config.in_channels - 1 + latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents( + video=video, + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames_in=num_frames_in, + num_frames_out=num_frames, + do_classifier_free_guidance=self.do_classifier_free_guidance, + dtype=torch.float32, + device=device, + generator=generator, + latents=latents, + ) + cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep + cond_mask = cond_mask.to(transformer_dtype) + + padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype) + + # Denoising loop + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + gt_velocity = (latents - cond_latent) * cond_mask + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t.cpu().item() + + # NOTE: assumes sigma(t) \in [0, 1] + sigma_t = ( + torch.tensor(self.scheduler.sigmas[i].item()) + .unsqueeze(0) + .to(device=device, dtype=transformer_dtype) + ) + + in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents + in_latents = in_latents.to(transformer_dtype) + in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t + noise_pred = self.transformer( + hidden_states=in_latents, + condition_mask=cond_mask, + timestep=in_timestep, + encoder_hidden_states=prompt_embeds, + padding_mask=padding_mask, + return_dict=False, + )[0] + # NOTE: replace velocity (noise_pred) with gt_velocity for conditioning inputs only + noise_pred = gt_velocity + noise_pred * (1 - cond_mask) + + if self.do_classifier_free_guidance: + noise_pred_neg = self.transformer( + hidden_states=in_latents, + condition_mask=cond_mask, + timestep=in_timestep, + encoder_hidden_states=negative_prompt_embeds, + padding_mask=padding_mask, + return_dict=False, + )[0] + # NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only + noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask) + noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg) + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents_mean = self.latents_mean.to(latents.device, latents.dtype) + latents_std = self.latents_std.to(latents.device, latents.dtype) + latents = latents * latents_std + latents_mean + video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0] + video = self._match_num_frames(video, num_frames) + + assert self.safety_checker is not None + self.safety_checker.to(device) + video = self.video_processor.postprocess_video(video, output_type="np") + video = (video * 255).astype(np.uint8) + video_batch = [] + for vid in video: + vid = self.safety_checker.check_video_safety(vid) + video_batch.append(vid) + video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1 + video = torch.from_numpy(video).permute(0, 4, 1, 2, 3) + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CosmosPipelineOutput(frames=video) + + def _match_num_frames(self, video: torch.Tensor, target_num_frames: int) -> torch.Tensor: + if target_num_frames <= 0 or video.shape[2] == target_num_frames: + return video + + frames_per_latent = max(self.vae_scale_factor_temporal, 1) + video = torch.repeat_interleave(video, repeats=frames_per_latent, dim=2) + + current_frames = video.shape[2] + if current_frames < target_num_frames: + pad = video[:, :, -1:, :, :].repeat(1, 1, target_num_frames - current_frames, 1, 1) + video = torch.cat([video, pad], dim=2) + elif current_frames > target_num_frames: + video = video[:, :, :target_num_frames] + + return video diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 689c6a0635..5ea56b300b 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -217,6 +217,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): rescale_betas_zero_snr: bool = False, use_dynamic_shifting: bool = False, time_shift_type: Literal["exponential"] = "exponential", + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, ) -> None: if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") @@ -350,7 +352,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): log_sigmas = np.log(sigmas) sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + if self.config.use_flow_sigmas: + sigmas = sigmas / (sigmas + 1) + timesteps = (sigmas * self.config.num_train_timesteps).copy() + else: + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + if self.config.final_sigmas_type == "sigma_min": sigma_last = sigmas[-1] elif self.config.final_sigmas_type == "zero": diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 74a4146bfd..4e1eae211c 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -767,6 +767,21 @@ class ConsisIDPipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) +class Cosmos2_5_PredictBasePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class Cosmos2TextToImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/cosmos/cosmos_guardrail.py b/tests/pipelines/cosmos/cosmos_guardrail.py index 4de14fbaaf..c9ef597fdb 100644 --- a/tests/pipelines/cosmos/cosmos_guardrail.py +++ b/tests/pipelines/cosmos/cosmos_guardrail.py @@ -27,7 +27,7 @@ class DummyCosmosSafetyChecker(ModelMixin, ConfigMixin): def __init__(self) -> None: super().__init__() - self._dtype = torch.float32 + self.register_buffer("_device_tracker", torch.zeros(1, dtype=torch.float32), persistent=False) def check_text_safety(self, prompt: str) -> bool: return True @@ -35,13 +35,14 @@ class DummyCosmosSafetyChecker(ModelMixin, ConfigMixin): def check_video_safety(self, frames: np.ndarray) -> np.ndarray: return frames - def to(self, device: Union[str, torch.device] = None, dtype: torch.dtype = None) -> None: - self._dtype = dtype + def to(self, device: Union[str, torch.device] = None, dtype: torch.dtype = None): + module = super().to(device=device, dtype=dtype) + return module @property def device(self) -> torch.device: - return None + return self._device_tracker.device @property def dtype(self) -> torch.dtype: - return self._dtype + return self._device_tracker.dtype diff --git a/tests/pipelines/cosmos/test_cosmos2_5_predict.py b/tests/pipelines/cosmos/test_cosmos2_5_predict.py new file mode 100644 index 0000000000..54d4edb485 --- /dev/null +++ b/tests/pipelines/cosmos/test_cosmos2_5_predict.py @@ -0,0 +1,337 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import json +import os +import tempfile +import unittest + +import numpy as np +import torch +from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer + +from diffusers import ( + AutoencoderKLWan, + Cosmos2_5_PredictBasePipeline, + CosmosTransformer3DModel, + UniPCMultistepScheduler, +) + +from ...testing_utils import enable_full_determinism, torch_device +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np +from .cosmos_guardrail import DummyCosmosSafetyChecker + + +enable_full_determinism() + + +class Cosmos2_5_PredictBaseWrapper(Cosmos2_5_PredictBasePipeline): + @staticmethod + def from_pretrained(*args, **kwargs): + if "safety_checker" not in kwargs or kwargs["safety_checker"] is None: + safety_checker = DummyCosmosSafetyChecker() + device_map = kwargs.get("device_map", "cpu") + torch_dtype = kwargs.get("torch_dtype") + if device_map is not None or torch_dtype is not None: + safety_checker = safety_checker.to(device_map, dtype=torch_dtype) + kwargs["safety_checker"] = safety_checker + return Cosmos2_5_PredictBasePipeline.from_pretrained(*args, **kwargs) + + +class Cosmos2_5_PredictPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = Cosmos2_5_PredictBaseWrapper + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + supports_dduf = False + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = CosmosTransformer3DModel( + in_channels=16 + 1, + out_channels=16, + num_attention_heads=2, + attention_head_dim=16, + num_layers=2, + mlp_ratio=2, + text_embed_dim=32, + adaln_lora_dim=4, + max_size=(4, 32, 32), + patch_size=(1, 2, 2), + rope_scale=(2.0, 1.0, 1.0), + concat_padding_mask=True, + extra_pos_embed_type="learnable", + ) + + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = UniPCMultistepScheduler() + + torch.manual_seed(0) + config = Qwen2_5_VLConfig( + text_config={ + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "rope_scaling": { + "mrope_section": [1, 1, 2], + "rope_type": "default", + "type": "default", + }, + "rope_theta": 1000000.0, + }, + vision_config={ + "depth": 2, + "hidden_size": 16, + "intermediate_size": 16, + "num_heads": 2, + "out_hidden_size": 16, + }, + hidden_size=16, + vocab_size=152064, + vision_end_token_id=151653, + vision_start_token_id=151652, + vision_token_id=151654, + ) + text_encoder = Qwen2_5_VLForConditionalGeneration(config) + tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": DummyCosmosSafetyChecker(), + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "dance monkey", + "negative_prompt": "bad quality", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 3.0, + "height": 32, + "width": 32, + "num_frames": 3, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_components_function(self): + init_components = self.get_dummy_components() + init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))} + pipe = self.pipeline_class(**init_components) + self.assertTrue(hasattr(pipe, "components")) + self.assertTrue(set(pipe.components.keys()) == set(init_components.keys())) + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + self.assertEqual(generated_video.shape, (3, 3, 32, 32)) + self.assertTrue(torch.isfinite(generated_video).all()) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + for tensor_name in callback_kwargs.keys(): + assert tensor_name in pipe._callback_tensor_inputs + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + for tensor_name in callback_kwargs.keys(): + assert tensor_name in pipe._callback_tensor_inputs + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + _ = pipe(**inputs)[0] + + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + _ = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=1e-2) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not getattr(self, "test_attention_slicing", True): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_save_load_optional_components(self, expected_max_difference=1e-4): + self.pipeline_class._optional_components.remove("safety_checker") + super().test_save_load_optional_components(expected_max_difference=expected_max_difference) + self.pipeline_class._optional_components.append("safety_checker") + + def test_serialization_with_variants(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + model_components = [ + component_name + for component_name, component in pipe.components.items() + if isinstance(component, torch.nn.Module) + ] + model_components.remove("safety_checker") + variant = "fp16" + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False) + + with open(f"{tmpdir}/model_index.json", "r") as f: + config = json.load(f) + + for subfolder in os.listdir(tmpdir): + if not os.path.isfile(subfolder) and subfolder in model_components: + folder_path = os.path.join(tmpdir, subfolder) + is_folder = os.path.isdir(folder_path) and subfolder in config + assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)) + + def test_torch_dtype_dict(self): + components = self.get_dummy_components() + if not components: + self.skipTest("No dummy components defined.") + + pipe = self.pipeline_class(**components) + + specified_key = next(iter(components.keys())) + + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname: + pipe.save_pretrained(tmpdirname, safe_serialization=False) + torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16} + loaded_pipe = self.pipeline_class.from_pretrained( + tmpdirname, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=torch_dtype_dict + ) + + for name, component in loaded_pipe.components.items(): + if name == "safety_checker": + continue + if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"): + expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32)) + self.assertEqual( + component.dtype, + expected_dtype, + f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}", + ) + + @unittest.skip( + "The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in " + "a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is " + "too large and slow to run on CI." + ) + def test_encode_prompt_works_in_isolation(self): + pass diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py index 197c831cb0..ac7e1d3f88 100644 --- a/tests/schedulers/test_scheduler_unipc.py +++ b/tests/schedulers/test_scheduler_unipc.py @@ -399,3 +399,32 @@ class UniPCMultistepScheduler1DTest(UniPCMultistepSchedulerTest): def test_exponential_sigmas(self): self.check_over_configs(use_exponential_sigmas=True) + + def test_flow_and_karras_sigmas(self): + self.check_over_configs(use_flow_sigmas=True, use_karras_sigmas=True) + + def test_flow_and_karras_sigmas_values(self): + num_train_timesteps = 1000 + num_inference_steps = 5 + scheduler = UniPCMultistepScheduler( + sigma_min=0.01, + sigma_max=200.0, + use_flow_sigmas=True, + use_karras_sigmas=True, + num_train_timesteps=num_train_timesteps, + ) + scheduler.set_timesteps(num_inference_steps=num_inference_steps) + + expected_sigmas = [ + 0.9950248599052429, + 0.9787454605102539, + 0.8774884343147278, + 0.3604971766471863, + 0.009900986216962337, + 0.0, # 0 appended as default + ] + expected_sigmas = torch.tensor(expected_sigmas) + expected_timesteps = (expected_sigmas * num_train_timesteps).to(torch.int64) + expected_timesteps = expected_timesteps[0:-1] + self.assertTrue(torch.allclose(scheduler.sigmas, expected_sigmas)) + self.assertTrue(torch.all(expected_timesteps == scheduler.timesteps)) From f7753b1bc8b4b3b97dc7f71d51ccb3a281b17b48 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Thu, 18 Dec 2025 19:25:20 -1000 Subject: [PATCH 2/8] more update in modular (#12560) * move node registry to mellon * up * fix * modula rpipeline update: filter out none for input_names, fix default blocks for pipe.init() and allow user pass additional kwargs_type in a dict * qwen modular refactor, unpack before decode * update mellon node config, adding* to required_inputs and required_model_inputs * modularpipeline.from_pretrained: error out if no config found * add a component_names property to modular blocks to be consistent! * flux image_encoder -> vae_encoder * controlnet_bundle * refator MellonNodeConfig MellonPipelineConfig * refactor & simplify mellon utils * vae_image_encoder -> vae_encoder * mellon config save keep key order * style + copies * add kwargs input for zimage --- .../modular_pipelines/flux/modular_blocks.py | 4 +- .../modular_pipelines/mellon_node_utils.py | 1141 +++++++---------- .../modular_pipelines/modular_pipeline.py | 25 +- .../modular_pipelines/qwenimage/decoders.py | 54 +- .../qwenimage/modular_blocks.py | 29 +- .../modular_pipelines/qwenimage/node_utils.py | 95 -- .../stable_diffusion_xl/node_utils.py | 99 -- .../modular_pipelines/z_image/denoise.py | 4 + .../z_image/modular_blocks.py | 8 +- 9 files changed, 585 insertions(+), 874 deletions(-) delete mode 100644 src/diffusers/modular_pipelines/qwenimage/node_utils.py delete mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/node_utils.py diff --git a/src/diffusers/modular_pipelines/flux/modular_blocks.py b/src/diffusers/modular_pipelines/flux/modular_blocks.py index a80bc2a5f7..bd9b2d1b40 100644 --- a/src/diffusers/modular_pipelines/flux/modular_blocks.py +++ b/src/diffusers/modular_pipelines/flux/modular_blocks.py @@ -360,7 +360,7 @@ class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks): AUTO_BLOCKS = InsertableDict( [ ("text_encoder", FluxTextEncoderStep()), - ("image_encoder", FluxAutoVaeEncoderStep()), + ("vae_encoder", FluxAutoVaeEncoderStep()), ("denoise", FluxCoreDenoiseStep()), ("decode", FluxDecodeStep()), ] @@ -369,7 +369,7 @@ AUTO_BLOCKS = InsertableDict( AUTO_BLOCKS_KONTEXT = InsertableDict( [ ("text_encoder", FluxTextEncoderStep()), - ("image_encoder", FluxKontextAutoVaeEncoderStep()), + ("vae_encoder", FluxKontextAutoVaeEncoderStep()), ("denoise", FluxKontextCoreDenoiseStep()), ("decode", FluxDecodeStep()), ] diff --git a/src/diffusers/modular_pipelines/mellon_node_utils.py b/src/diffusers/modular_pipelines/mellon_node_utils.py index a405aebee2..4f142a453f 100644 --- a/src/diffusers/modular_pipelines/mellon_node_utils.py +++ b/src/diffusers/modular_pipelines/mellon_node_utils.py @@ -4,315 +4,31 @@ import os # Simple typed wrapper for parameter overrides from dataclasses import asdict, dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Optional, Union -from huggingface_hub import create_repo, hf_hub_download +from huggingface_hub import create_repo, hf_hub_download, upload_folder from huggingface_hub.utils import ( EntryNotFoundError, HfHubHTTPError, RepositoryNotFoundError, RevisionNotFoundError, - validate_hf_hub_args, ) -from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, PushToHubMixin, extract_commit_hash -from .modular_pipeline import ModularPipelineBlocks +from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT logger = logging.getLogger(__name__) -SUPPORTED_NODE_TYPES = {"controlnet", "vae_encoder", "denoise", "text_encoder", "decoder"} - - -# Mellon Input Parameters (runtime parameters, not models) -MELLON_INPUT_PARAMS = { - # controlnet - "control_image": { - "label": "Control Image", - "type": "image", - "display": "input", - }, - "controlnet_conditioning_scale": { - "label": "Scale", - "type": "float", - "default": 0.5, - "min": 0, - "max": 1, - }, - "control_guidance_end": { - "label": "End", - "type": "float", - "default": 1.0, - "min": 0, - "max": 1, - }, - "control_guidance_start": { - "label": "Start", - "type": "float", - "default": 0.0, - "min": 0, - "max": 1, - }, - "controlnet": { - "label": "Controlnet", - "type": "custom_controlnet", - "display": "input", - }, - "embeddings": { - "label": "Text Embeddings", - "display": "input", - "type": "embeddings", - }, - "image": { - "label": "Image", - "type": "image", - "display": "input", - }, - "negative_prompt": { - "label": "Negative Prompt", - "type": "string", - "default": "", - "display": "textarea", - }, - "prompt": { - "label": "Prompt", - "type": "string", - "default": "", - "display": "textarea", - }, - "guidance_scale": { - "label": "Guidance Scale", - "type": "float", - "display": "slider", - "default": 5, - "min": 1.0, - "max": 30.0, - "step": 0.1, - }, - "height": { - "label": "Height", - "type": "int", - "default": 1024, - "min": 64, - "step": 8, - }, - "image_latents": { - "label": "Image Latents", - "type": "latents", - "display": "input", - "onChange": {False: ["height", "width"], True: ["strength"]}, - }, - "latents": { - "label": "Latents", - "type": "latents", - "display": "input", - }, - "num_inference_steps": { - "label": "Steps", - "type": "int", - "display": "slider", - "default": 25, - "min": 1, - "max": 100, - }, - "seed": { - "label": "Seed", - "type": "int", - "display": "random", - "default": 0, - "min": 0, - "max": 4294967295, - }, - "strength": { - "label": "Strength", - "type": "float", - "default": 0.5, - "min": 0.0, - "max": 1.0, - "step": 0.01, - }, - "width": { - "label": "Width", - "type": "int", - "default": 1024, - "min": 64, - "step": 8, - }, - "ip_adapter": { - "label": "IP Adapter", - "type": "custom_ip_adapter", - "display": "input", - }, -} - -# Mellon Model Parameters (diffusers_auto_model types) -MELLON_MODEL_PARAMS = { - "scheduler": { - "label": "Scheduler", - "display": "input", - "type": "diffusers_auto_model", - }, - "text_encoders": { - "label": "Text Encoders", - "type": "diffusers_auto_models", - "display": "input", - }, - "unet": { - "label": "Unet", - "display": "input", - "type": "diffusers_auto_model", - "onSignal": { - "action": "signal", - "target": "guider", - }, - }, - "guider": { - "label": "Guider", - "display": "input", - "type": "custom_guider", - "onChange": {False: ["guidance_scale"], True: []}, - }, - "vae": { - "label": "VAE", - "display": "input", - "type": "diffusers_auto_model", - }, - "controlnet": { - "label": "Controlnet Model", - "type": "diffusers_auto_model", - "display": "input", - }, -} - -# Mellon Output Parameters (display = "output") -MELLON_OUTPUT_PARAMS = { - "embeddings": { - "label": "Text Embeddings", - "display": "output", - "type": "embeddings", - }, - "images": { - "label": "Images", - "type": "image", - "display": "output", - }, - "image_latents": { - "label": "Image Latents", - "type": "latents", - "display": "output", - }, - "latents": { - "label": "Latents", - "type": "latents", - "display": "output", - }, - "latents_preview": { - "label": "Latents Preview", - "display": "output", - "type": "latent", - }, - "controlnet_out": { - "label": "Controlnet", - "display": "output", - "type": "controlnet", - }, -} - - -# Default param selections per supported node_type -# from MELLON_INPUT_PARAMS / MELLON_MODEL_PARAMS / MELLON_OUTPUT_PARAMS. -NODE_TYPE_PARAMS_MAP = { - "controlnet": { - "inputs": [ - "control_image", - "controlnet_conditioning_scale", - "control_guidance_start", - "control_guidance_end", - "height", - "width", - ], - "model_inputs": [ - "controlnet", - "vae", - ], - "outputs": [ - "controlnet", - ], - "block_names": ["controlnet_vae_encoder"], - }, - "denoise": { - "inputs": [ - "embeddings", - "width", - "height", - "seed", - "num_inference_steps", - "guidance_scale", - "image_latents", - "strength", - # custom adapters coming in as inputs - "controlnet", - # ip_adapter is optional and custom; include if available - "ip_adapter", - ], - "model_inputs": [ - "unet", - "guider", - "scheduler", - ], - "outputs": [ - "latents", - "latents_preview", - ], - "block_names": ["denoise"], - }, - "vae_encoder": { - "inputs": [ - "image", - "width", - "height", - ], - "model_inputs": [ - "vae", - ], - "outputs": [ - "image_latents", - ], - "block_names": ["vae_encoder"], - }, - "text_encoder": { - "inputs": [ - "prompt", - "negative_prompt", - # optional image prompt input supported in embeddings node - "image", - ], - "model_inputs": [ - "text_encoders", - ], - "outputs": [ - "embeddings", - ], - "block_names": ["text_encoder"], - }, - "decoder": { - "inputs": [ - "latents", - ], - "model_inputs": [ - "vae", - ], - "outputs": [ - "images", - ], - "block_names": ["decode"], - }, -} - - @dataclass(frozen=True) class MellonParam: + """ + Parameter definition for Mellon nodes. + + Use factory methods for common params (e.g., MellonParam.seed()) or create custom ones with MellonParam(name="...", + label="...", type="..."). + """ + name: str label: str type: str @@ -326,122 +42,482 @@ class MellonParam: fieldOptions: Optional[Dict[str, Any]] = None onChange: Any = None onSignal: Any = None - _map_to_input: Any = None # the block input name this parameter maps to def to_dict(self) -> Dict[str, Any]: + """Convert to dict for Mellon schema, excluding None values and name.""" data = asdict(self) - return {k: v for k, v in data.items() if not k.startswith("_") and v is not None} - - -@dataclass -class MellonNodeConfig(PushToHubMixin): - """ - A MellonNodeConfig is a base class to build Mellon nodes UI with modular diffusers. - - - - This is an experimental feature and is likely to change in the future. - - - """ - - inputs: List[Union[str, MellonParam]] - model_inputs: List[Union[str, MellonParam]] - outputs: List[Union[str, MellonParam]] - blocks_names: list[str] - node_type: str - config_name = "mellon_config.json" - - def __post_init__(self): - if isinstance(self.inputs, list): - self.inputs = self._resolve_params_list(self.inputs, MELLON_INPUT_PARAMS) - if isinstance(self.model_inputs, list): - self.model_inputs = self._resolve_params_list(self.model_inputs, MELLON_MODEL_PARAMS) - if isinstance(self.outputs, list): - self.outputs = self._resolve_params_list(self.outputs, MELLON_OUTPUT_PARAMS) - - @staticmethod - def _resolve_params_list( - params: List[Union[str, MellonParam]], default_map: Dict[str, Dict[str, Any]] - ) -> Dict[str, Dict[str, Any]]: - def _resolve_param( - param: Union[str, MellonParam], default_params_map: Dict[str, Dict[str, Any]] - ) -> Tuple[str, Dict[str, Any]]: - if isinstance(param, str): - if param not in default_params_map: - raise ValueError(f"Unknown param '{param}', please define a `MellonParam` object instead") - return param, default_params_map[param].copy() - elif isinstance(param, MellonParam): - param_dict = param.to_dict() - param_name = param_dict.pop("name") - return param_name, param_dict - else: - raise ValueError( - f"Unknown param type '{type(param)}', please use a string or a `MellonParam` object instead" - ) - - resolved = {} - for p in params: - logger.info(f" Resolving param: {p}") - name, cfg = _resolve_param(p, default_map) - if name in resolved: - raise ValueError(f"Duplicate param '{name}'") - resolved[name] = cfg - return resolved + return {k: v for k, v in data.items() if v is not None and k != "name"} @classmethod - @validate_hf_hub_args - def load_mellon_config( + def image(cls) -> "MellonParam": + return cls(name="image", label="Image", type="image", display="input") + + @classmethod + def images(cls) -> "MellonParam": + return cls(name="images", label="Images", type="image", display="output") + + @classmethod + def control_image(cls, display: str = "input") -> "MellonParam": + return cls(name="control_image", label="Control Image", type="image", display=display) + + @classmethod + def latents(cls, display: str = "input") -> "MellonParam": + return cls(name="latents", label="Latents", type="latents", display=display) + + @classmethod + def image_latents(cls, display: str = "input") -> "MellonParam": + return cls(name="image_latents", label="Image Latents", type="latents", display=display) + + @classmethod + def image_latents_with_strength(cls) -> "MellonParam": + return cls( + name="image_latents", + label="Image Latents", + type="latents", + display="input", + onChange={"false": ["height", "width"], "true": ["strength"]}, + ) + + @classmethod + def latents_preview(cls) -> "MellonParam": + """ + `Latents Preview` is a special output parameter that is used to preview the latents in the UI. + """ + return cls(name="latents_preview", label="Latents Preview", type="latent", display="output") + + @classmethod + def embeddings(cls, display: str = "output") -> "MellonParam": + return cls(name="embeddings", label="Text Embeddings", type="embeddings", display=display) + + @classmethod + def controlnet_conditioning_scale(cls, default: float = 0.5) -> "MellonParam": + return cls( + name="controlnet_conditioning_scale", + label="Controlnet Conditioning Scale", + type="float", + default=default, + min=0.0, + max=1.0, + step=0.01, + ) + + @classmethod + def control_guidance_start(cls, default: float = 0.0) -> "MellonParam": + return cls( + name="control_guidance_start", + label="Control Guidance Start", + type="float", + default=default, + min=0.0, + max=1.0, + step=0.01, + ) + + @classmethod + def control_guidance_end(cls, default: float = 1.0) -> "MellonParam": + return cls( + name="control_guidance_end", + label="Control Guidance End", + type="float", + default=default, + min=0.0, + max=1.0, + step=0.01, + ) + + @classmethod + def prompt(cls, default: str = "") -> "MellonParam": + return cls(name="prompt", label="Prompt", type="string", default=default, display="textarea") + + @classmethod + def negative_prompt(cls, default: str = "") -> "MellonParam": + return cls(name="negative_prompt", label="Negative Prompt", type="string", default=default, display="textarea") + + @classmethod + def strength(cls, default: float = 0.5) -> "MellonParam": + return cls(name="strength", label="Strength", type="float", default=default, min=0.0, max=1.0, step=0.01) + + @classmethod + def guidance_scale(cls, default: float = 5.0) -> "MellonParam": + return cls( + name="guidance_scale", + label="Guidance Scale", + type="float", + display="slider", + default=default, + min=1.0, + max=30.0, + step=0.1, + ) + + @classmethod + def height(cls, default: int = 1024) -> "MellonParam": + return cls(name="height", label="Height", type="int", default=default, min=64, step=8) + + @classmethod + def width(cls, default: int = 1024) -> "MellonParam": + return cls(name="width", label="Width", type="int", default=default, min=64, step=8) + + @classmethod + def seed(cls, default: int = 0) -> "MellonParam": + return cls(name="seed", label="Seed", type="int", default=default, min=0, max=4294967295, display="random") + + @classmethod + def num_inference_steps(cls, default: int = 25) -> "MellonParam": + return cls( + name="num_inference_steps", label="Steps", type="int", default=default, min=1, max=100, display="slider" + ) + + @classmethod + def vae(cls) -> "MellonParam": + """ + VAE model info dict. + + Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve + the actual model. + """ + return cls(name="vae", label="VAE", type="diffusers_auto_model", display="input") + + @classmethod + def unet(cls) -> "MellonParam": + """ + Denoising model (UNet/Transformer) info dict. + + Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve + the actual model. + """ + return cls(name="unet", label="Denoise Model", type="diffusers_auto_model", display="input") + + @classmethod + def scheduler(cls) -> "MellonParam": + """ + Scheduler model info dict. + + Contains keys like 'model_id', 'repo_id' etc. Use components.get_one(model_id) to retrieve the actual + scheduler. + """ + return cls(name="scheduler", label="Scheduler", type="diffusers_auto_model", display="input") + + @classmethod + def controlnet(cls) -> "MellonParam": + """ + ControlNet model info dict. + + Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve + the actual model. + """ + return cls(name="controlnet", label="ControlNet Model", type="diffusers_auto_model", display="input") + + @classmethod + def text_encoders(cls) -> "MellonParam": + """ + Dict of text encoder model info dicts. + + Structure: { + 'text_encoder': {'model_id': ..., 'execution_device': ..., ...}, 'tokenizer': {'model_id': ..., ...}, + 'repo_id': '...' + } Use components.get_one(model_id) to retrieve each model. + """ + return cls(name="text_encoders", label="Text Encoders", type="diffusers_auto_models", display="input") + + @classmethod + def controlnet_bundle(cls, display: str = "input") -> "MellonParam": + """ + ControlNet bundle containing model info and processed control inputs. + + Structure: { + 'controlnet': {'model_id': ..., ...}, # controlnet model info dict 'control_image': ..., # processed + control image/embeddings 'controlnet_conditioning_scale': ..., ... # other inputs expected by denoise + blocks + } + + Output from Controlnet node, input to Denoise node. + """ + return cls(name="controlnet_bundle", label="ControlNet", type="custom_controlnet", display=display) + + @classmethod + def ip_adapter(cls) -> "MellonParam": + return cls(name="ip_adapter", label="IP Adapter", type="custom_ip_adapter", display="input") + + @classmethod + def guider(cls) -> "MellonParam": + return cls( + name="guider", + label="Guider", + type="custom_guider", + display="input", + onChange={False: ["guidance_scale"], True: []}, + ) + + @classmethod + def doc(cls) -> "MellonParam": + return cls(name="doc", label="Doc", type="string", display="output") + + +def mark_required(label: str, marker: str = " *") -> str: + """Add required marker to label if not already present.""" + if label.endswith(marker): + return label + return f"{label}{marker}" + + +def node_spec_to_mellon_dict(node_spec: Dict[str, Any], node_type: str) -> Dict[str, Any]: + """ + Convert a node spec dict into Mellon format. + + A node spec is how we define a Mellon diffusers node in code. This function converts it into the `params` map + format that Mellon UI expects. + + The `params` map is a dict where keys are parameter names and values are UI configuration: + ```python + {"seed": {"label": "Seed", "type": "int", "default": 0}} + ``` + + For Modular Mellon nodes, we need to distinguish: + - `inputs`: Pipeline inputs (e.g., seed, prompt, image) + - `model_inputs`: Model components (e.g., unet, vae, scheduler) + - `outputs`: Node outputs (e.g., latents, images) + + The node spec also includes: + - `required_inputs` / `required_model_inputs`: Which params are required (marked with *) + - `block_name`: The modular pipeline block this node corresponds to on backend + + We provide factory methods for common parameters (e.g., `MellonParam.seed()`, `MellonParam.unet()`) so you don't + have to manually specify all the UI configuration. + + Args: + node_spec: Dict with `inputs`, `model_inputs`, `outputs` (lists of MellonParam), + plus `required_inputs`, `required_model_inputs`, `block_name`. + node_type: The node type string (e.g., "denoise", "controlnet") + + Returns: + Dict with: + - `params`: Flat dict of all params in Mellon UI format + - `input_names`: List of input parameter names + - `model_input_names`: List of model input parameter names + - `output_names`: List of output parameter names + - `block_name`: The backend block name + - `node_type`: The node type + + Example: + ```python + node_spec = { + "inputs": [MellonParam.seed(), MellonParam.prompt()], + "model_inputs": [MellonParam.unet()], + "outputs": [MellonParam.latents(display="output")], + "required_inputs": ["prompt"], + "required_model_inputs": ["unet"], + "block_name": "denoise", + } + + result = node_spec_to_mellon_dict(node_spec, "denoise") + # Returns: + # { + # "params": { + # "seed": {"label": "Seed", "type": "int", "default": 0}, + # "prompt": {"label": "Prompt *", "type": "string", "default": ""}, # * marks required + # "unet": {"label": "Denoise Model *", "type": "diffusers_auto_model", "display": "input"}, + # "latents": {"label": "Latents", "type": "latents", "display": "output"}, + # }, + # "input_names": ["seed", "prompt"], + # "model_input_names": ["unet"], + # "output_names": ["latents"], + # "block_name": "denoise", + # "node_type": "denoise", + # } + ``` + """ + params = {} + input_names = [] + model_input_names = [] + output_names = [] + + required_inputs = node_spec.get("required_inputs", []) + required_model_inputs = node_spec.get("required_model_inputs", []) + + # Process inputs + for p in node_spec.get("inputs", []): + param_dict = p.to_dict() + if p.name in required_inputs: + param_dict["label"] = mark_required(param_dict["label"]) + params[p.name] = param_dict + input_names.append(p.name) + + # Process model_inputs + for p in node_spec.get("model_inputs", []): + param_dict = p.to_dict() + if p.name in required_model_inputs: + param_dict["label"] = mark_required(param_dict["label"]) + params[p.name] = param_dict + model_input_names.append(p.name) + + # Process outputs + for p in node_spec.get("outputs", []): + params[p.name] = p.to_dict() + output_names.append(p.name) + + return { + "params": params, + "input_names": input_names, + "model_input_names": model_input_names, + "output_names": output_names, + "block_name": node_spec.get("block_name"), + "node_type": node_type, + } + + +class MellonPipelineConfig: + """ + Configuration for an entire Mellon pipeline containing multiple nodes. + + Accepts node specs as dicts with inputs/model_inputs/outputs lists of MellonParam, converts them to Mellon-ready + format, and handles save/load to Hub. + + Example: + ```python + config = MellonPipelineConfig( + node_specs={ + "denoise": { + "inputs": [MellonParam.seed(), MellonParam.prompt()], + "model_inputs": [MellonParam.unet()], + "outputs": [MellonParam.latents(display="output")], + "required_inputs": ["prompt"], + "required_model_inputs": ["unet"], + "block_name": "denoise", + }, + "decoder": { + "inputs": [MellonParam.latents(display="input")], + "outputs": [MellonParam.images()], + "block_name": "decoder", + }, + }, + label="My Pipeline", + default_repo="user/my-pipeline", + default_dtype="float16", + ) + + # Access Mellon format dict + denoise = config.node_params["denoise"] + input_names = denoise["input_names"] + params = denoise["params"] + + # Save to Hub + config.save("./my_config", push_to_hub=True, repo_id="user/my-pipeline") + + # Load from Hub + loaded = MellonPipelineConfig.load("user/my-pipeline") + ``` + """ + + config_name = "mellon_pipeline_config.json" + + def __init__( + self, + node_specs: Dict[str, Optional[Dict[str, Any]]], + label: str = "", + default_repo: str = "", + default_dtype: str = "", + ): + """ + Args: + node_specs: Dict mapping node_type to node spec or None. + Node spec has: inputs, model_inputs, outputs, required_inputs, required_model_inputs, + block_name (all optional) + label: Human-readable label for the pipeline + default_repo: Default HuggingFace repo for this pipeline + default_dtype: Default dtype (e.g., "float16", "bfloat16") + """ + # Convert all node specs to Mellon format immediately + self.node_params = {} + for node_type, spec in node_specs.items(): + if spec is None: + self.node_params[node_type] = None + else: + self.node_params[node_type] = node_spec_to_mellon_dict(spec, node_type) + + self.label = label + self.default_repo = default_repo + self.default_dtype = default_dtype + + def __repr__(self) -> str: + node_types = list(self.node_params.keys()) + return f"MellonPipelineConfig(label={self.label!r}, default_repo={self.default_repo!r}, default_dtype={self.default_dtype!r}, node_params={node_types})" + + def to_dict(self) -> Dict[str, Any]: + """Convert to a JSON-serializable dictionary.""" + return { + "label": self.label, + "default_repo": self.default_repo, + "default_dtype": self.default_dtype, + "node_params": self.node_params, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "MellonPipelineConfig": + """ + Create from a dictionary (loaded from JSON). + + Note: The mellon_params are already in Mellon format when loading from JSON. + """ + instance = cls.__new__(cls) + instance.node_params = data.get("node_params", {}) + instance.label = data.get("label", "") + instance.default_repo = data.get("default_repo", "") + instance.default_dtype = data.get("default_dtype", "") + return instance + + def to_json_string(self) -> str: + """Serialize to JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=False) + "\n" + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """Save to a JSON file.""" + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + @classmethod + def from_json_file(cls, json_file_path: Union[str, os.PathLike]) -> "MellonPipelineConfig": + """Load from a JSON file.""" + with open(json_file_path, "r", encoding="utf-8") as reader: + data = json.load(reader) + return cls.from_dict(data) + + def save(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """Save the pipeline config to a directory.""" + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + output_path = os.path.join(save_directory, self.config_name) + self.to_json_file(output_path) + logger.info(f"Pipeline config saved to {output_path}") + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + private = kwargs.pop("private", None) + create_pr = kwargs.pop("create_pr", False) + token = kwargs.pop("token", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + subfolder = kwargs.pop("subfolder", None) + + upload_folder( + repo_id=repo_id, + folder_path=save_directory, + token=token, + commit_message=commit_message or "Upload MellonPipelineConfig", + create_pr=create_pr, + path_in_repo=subfolder, + ) + logger.info(f"Pipeline config pushed to hub: {repo_id}") + + @classmethod + def load( cls, pretrained_model_name_or_path: Union[str, os.PathLike], - return_unused_kwargs=False, - return_commit_hash=False, **kwargs, - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - r""" - Load a model or scheduler configuration. - - Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with - [`~ConfigMixin.save_config`]. - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(`bool`, *optional*, defaults to `False`): - Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - return_unused_kwargs (`bool`, *optional*, defaults to `False): - Whether unused keyword arguments of the config are returned. - return_commit_hash (`bool`, *optional*, defaults to `False): - Whether the `commit_hash` of the loaded configuration are returned. - - Returns: - `dict`: - A dictionary of all the parameters stored in a JSON configuration file. - - """ + ) -> "MellonPipelineConfig": + """Load a pipeline config from a local path or Hugging Face Hub.""" cache_dir = kwargs.pop("cache_dir", None) local_dir = kwargs.pop("local_dir", None) local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto") @@ -450,27 +526,18 @@ class MellonNodeConfig(PushToHubMixin): token = kwargs.pop("token", None) local_files_only = kwargs.pop("local_files_only", False) revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) pretrained_model_name_or_path = str(pretrained_model_name_or_path) - if cls.config_name is None: - raise ValueError( - "`self.config_name` is not defined. Note that one should not load a config from " - "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`" - ) if os.path.isfile(pretrained_model_name_or_path): config_file = pretrained_model_name_or_path elif os.path.isdir(pretrained_model_name_or_path): - if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)): - # Load from a PyTorch checkpoint - config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) - else: - raise EnvironmentError( - f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}." - ) + config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) + if not os.path.isfile(config_file): + raise EnvironmentError(f"No file named {cls.config_name} found in {pretrained_model_name_or_path}") else: try: - # Load from URL or cache if already cached config_file = hf_hub_download( pretrained_model_name_or_path, filename=cls.config_name, @@ -480,6 +547,7 @@ class MellonNodeConfig(PushToHubMixin): local_files_only=local_files_only, token=token, revision=revision, + subfolder=subfolder, local_dir=local_dir, local_dir_use_symlinks=local_dir_use_symlinks, ) @@ -519,245 +587,8 @@ class MellonNodeConfig(PushToHubMixin): f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " f"containing a {cls.config_name} file" ) + try: - with open(config_file, "r", encoding="utf-8") as reader: - text = reader.read() - config_dict = json.loads(text) - - commit_hash = extract_commit_hash(config_file) + return cls.from_json_file(config_file) except (json.JSONDecodeError, UnicodeDecodeError): - raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.") - - if not (return_unused_kwargs or return_commit_hash): - return config_dict - - outputs = (config_dict,) - - if return_unused_kwargs: - outputs += (kwargs,) - - if return_commit_hash: - outputs += (commit_hash,) - - return outputs - - def save_mellon_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): - """ - Save the Mellon node definition to a JSON file. - - Args: - save_directory (`str` or `os.PathLike`): - Directory where the configuration JSON file is saved (will be created if it does not exist). - push_to_hub (`bool`, *optional*, defaults to `False`): - Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the - repository you want to push to with `repo_id` (will default to the name of `save_directory` in your - namespace). - kwargs (`Dict[str, Any]`, *optional*): - Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. - """ - if os.path.isfile(save_directory): - raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") - - os.makedirs(save_directory, exist_ok=True) - - # If we save using the predefined names, we can load using `from_config` - output_config_file = os.path.join(save_directory, self.config_name) - - self.to_json_file(output_config_file) - logger.info(f"Mellon node definition saved in {output_config_file}") - - if push_to_hub: - commit_message = kwargs.pop("commit_message", None) - private = kwargs.pop("private", None) - create_pr = kwargs.pop("create_pr", False) - token = kwargs.pop("token", None) - repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) - repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id - subfolder = kwargs.pop("subfolder", None) - - self._upload_folder( - save_directory, - repo_id, - token=token, - commit_message=commit_message, - create_pr=create_pr, - subfolder=subfolder, - ) - - def to_json_file(self, json_file_path: Union[str, os.PathLike]): - """ - Save the Mellon schema dictionary to a JSON file. - - Args: - json_file_path (`str` or `os.PathLike`): - Path to the JSON file to save a configuration instance's parameters. - """ - with open(json_file_path, "w", encoding="utf-8") as writer: - writer.write(self.to_json_string()) - - def to_json_string(self) -> str: - """ - Serializes this instance to a JSON string of the Mellon schema dict. - - Args: - Returns: - `str`: String containing all the attributes that make up this configuration instance in JSON format. - """ - - mellon_dict = self.to_mellon_dict() - return json.dumps(mellon_dict, indent=2, sort_keys=True) + "\n" - - def to_mellon_dict(self) -> Dict[str, Any]: - """Return a JSON-serializable dict focusing on the Mellon schema fields only. - - params is a single flat dict composed as: {**inputs, **model_inputs, **outputs}. - """ - # inputs/model_inputs/outputs are already normalized dicts - merged_params = {} - merged_params.update(self.inputs or {}) - merged_params.update(self.model_inputs or {}) - merged_params.update(self.outputs or {}) - - return { - "node_type": self.node_type, - "blocks_names": self.blocks_names, - "params": merged_params, - } - - @classmethod - def from_mellon_dict(cls, mellon_dict: Dict[str, Any]) -> "MellonNodeConfig": - """Create a config from a Mellon schema dict produced by to_mellon_dict(). - - Splits the flat params dict back into inputs/model_inputs/outputs using the known key spaces from - MELLON_INPUT_PARAMS / MELLON_MODEL_PARAMS / MELLON_OUTPUT_PARAMS. Unknown keys are treated as inputs by - default. - """ - flat_params = mellon_dict.get("params", {}) - - inputs: Dict[str, Any] = {} - model_inputs: Dict[str, Any] = {} - outputs: Dict[str, Any] = {} - - for param_name, param_dict in flat_params.items(): - if param_dict.get("display", "") == "output": - outputs[param_name] = param_dict - elif param_dict.get("type", "") in ("diffusers_auto_model", "diffusers_auto_models"): - model_inputs[param_name] = param_dict - else: - inputs[param_name] = param_dict - - return cls( - inputs=inputs, - model_inputs=model_inputs, - outputs=outputs, - blocks_names=mellon_dict.get("blocks_names", []), - node_type=mellon_dict.get("node_type"), - ) - - # YiYi Notes: not used yet - @classmethod - def from_blocks(cls, blocks: ModularPipelineBlocks, node_type: str) -> "MellonNodeConfig": - """ - Create an instance from a ModularPipeline object. If a preset exists in NODE_TYPE_PARAMS_MAP for the node_type, - use it; otherwise fall back to deriving lists from the pipeline's expected inputs/components/outputs. - """ - if node_type not in NODE_TYPE_PARAMS_MAP: - raise ValueError(f"Node type {node_type} not supported") - - blocks_names = list(blocks.sub_blocks.keys()) - - default_node_config = NODE_TYPE_PARAMS_MAP[node_type] - inputs_list: List[Union[str, MellonParam]] = default_node_config.get("inputs", []) - model_inputs_list: List[Union[str, MellonParam]] = default_node_config.get("model_inputs", []) - outputs_list: List[Union[str, MellonParam]] = default_node_config.get("outputs", []) - - for required_input_name in blocks.required_inputs: - if required_input_name not in inputs_list: - inputs_list.append( - MellonParam( - name=required_input_name, label=required_input_name, type=required_input_name, display="input" - ) - ) - - for component_spec in blocks.expected_components: - if component_spec.name not in model_inputs_list: - model_inputs_list.append( - MellonParam( - name=component_spec.name, - label=component_spec.name, - type="diffusers_auto_model", - display="input", - ) - ) - - return cls( - inputs=inputs_list, - model_inputs=model_inputs_list, - outputs=outputs_list, - blocks_names=blocks_names, - node_type=node_type, - ) - - -# Minimal modular registry for Mellon node configs -class ModularMellonNodeRegistry: - """Registry mapping (pipeline class, blocks_name) -> list of MellonNodeConfig.""" - - def __init__(self): - self._registry = {} - self._initialized = False - - def register(self, pipeline_cls: type, node_params: Dict[str, MellonNodeConfig]): - if not self._initialized: - _initialize_registry(self) - self._registry[pipeline_cls] = node_params - - def get(self, pipeline_cls: type) -> MellonNodeConfig: - if not self._initialized: - _initialize_registry(self) - return self._registry.get(pipeline_cls, None) - - def get_all(self) -> Dict[type, Dict[str, MellonNodeConfig]]: - if not self._initialized: - _initialize_registry(self) - return self._registry - - -def _register_preset_node_types( - pipeline_cls, params_map: Dict[str, Dict[str, Any]], registry: ModularMellonNodeRegistry -): - """Register all node-type presets for a given pipeline class from a params map.""" - node_configs = {} - for node_type, spec in params_map.items(): - node_config = MellonNodeConfig( - inputs=spec.get("inputs", []), - model_inputs=spec.get("model_inputs", []), - outputs=spec.get("outputs", []), - blocks_names=spec.get("block_names", []), - node_type=node_type, - ) - node_configs[node_type] = node_config - registry.register(pipeline_cls, node_configs) - - -def _initialize_registry(registry: ModularMellonNodeRegistry): - """Initialize the registry and register all available pipeline configs.""" - print("Initializing registry") - - registry._initialized = True - - try: - from .qwenimage.modular_pipeline import QwenImageModularPipeline - from .qwenimage.node_utils import QwenImage_NODE_TYPES_PARAMS_MAP - - _register_preset_node_types(QwenImageModularPipeline, QwenImage_NODE_TYPES_PARAMS_MAP, registry) - except Exception: - raise Exception("Failed to register QwenImageModularPipeline") - - try: - from .stable_diffusion_xl.modular_pipeline import StableDiffusionXLModularPipeline - from .stable_diffusion_xl.node_utils import SDXL_NODE_TYPES_PARAMS_MAP - - _register_preset_node_types(StableDiffusionXLModularPipeline, SDXL_NODE_TYPES_PARAMS_MAP, registry) - except Exception: - raise Exception("Failed to register StableDiffusionXLModularPipeline") + raise EnvironmentError(f"The config file at '{config_file}' is not a valid JSON file.") diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 17c0117bff..c5fa4cf992 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -501,15 +501,19 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): @property def input_names(self) -> List[str]: - return [input_param.name for input_param in self.inputs] + return [input_param.name for input_param in self.inputs if input_param.name is not None] @property def intermediate_output_names(self) -> List[str]: - return [output_param.name for output_param in self.intermediate_outputs] + return [output_param.name for output_param in self.intermediate_outputs if output_param.name is not None] @property def output_names(self) -> List[str]: - return [output_param.name for output_param in self.outputs] + return [output_param.name for output_param in self.outputs if output_param.name is not None] + + @property + def component_names(self) -> List[str]: + return [component.name for component in self.expected_components] @property def doc(self): @@ -1525,10 +1529,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin): if blocks is None: if modular_config_dict is not None: blocks_class_name = modular_config_dict.get("_blocks_class_name") - elif config_dict is not None: - blocks_class_name = self.get_default_blocks_name(config_dict) else: - blocks_class_name = None + blocks_class_name = self.get_default_blocks_name(config_dict) if blocks_class_name is not None: diffusers_module = importlib.import_module("diffusers") blocks_class = getattr(diffusers_module, blocks_class_name) @@ -1625,7 +1627,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin): return None, config_dict except EnvironmentError as e: - logger.debug(f" model_index.json not found in the repo: {e}") + raise EnvironmentError( + f"Failed to load config from '{pretrained_model_name_or_path}'. " + f"Could not find or load 'modular_model_index.json' or 'model_index.json'." + ) from e return None, None @@ -2550,7 +2555,11 @@ class ModularPipeline(ConfigMixin, PushToHubMixin): kwargs_type = expected_input_param.kwargs_type if name in passed_kwargs: state.set(name, passed_kwargs.pop(name), kwargs_type) - elif name not in state.values: + elif kwargs_type is not None and kwargs_type in passed_kwargs: + kwargs_dict = passed_kwargs.pop(kwargs_type) + for k, v in kwargs_dict.items(): + state.set(k, v, kwargs_type) + elif name is not None and name not in state.values: state.set(name, default, kwargs_type) # Warn about unexpected inputs diff --git a/src/diffusers/modular_pipelines/qwenimage/decoders.py b/src/diffusers/modular_pipelines/qwenimage/decoders.py index 26417162de..6e145f1855 100644 --- a/src/diffusers/modular_pipelines/qwenimage/decoders.py +++ b/src/diffusers/modular_pipelines/qwenimage/decoders.py @@ -30,6 +30,47 @@ from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier logger = logging.get_logger(__name__) +class QwenImageAfterDenoiseStep(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Step that unpack the latents from 3D tensor (batch_size, sequence_length, channels) into 5D tensor (batch_size, channels, 1, height, width)" + + @property + def expected_components(self) -> List[ComponentSpec]: + components = [ + ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"), + ] + + return components + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(name="height", required=True), + InputParam(name="width", required=True), + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The latents to decode, can be generated in the denoise step", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + vae_scale_factor = components.vae_scale_factor + block_state.latents = components.pachifier.unpack_latents( + block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor + ) + + self.set_block_state(state, block_state) + return components, state + + class QwenImageDecoderStep(ModularPipelineBlocks): model_name = "qwenimage" @@ -41,7 +82,6 @@ class QwenImageDecoderStep(ModularPipelineBlocks): def expected_components(self) -> List[ComponentSpec]: components = [ ComponentSpec("vae", AutoencoderKLQwenImage), - ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"), ] return components @@ -49,8 +89,6 @@ class QwenImageDecoderStep(ModularPipelineBlocks): @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="height", required=True), - InputParam(name="width", required=True), InputParam( name="latents", required=True, @@ -74,10 +112,12 @@ class QwenImageDecoderStep(ModularPipelineBlocks): block_state = self.get_block_state(state) # YiYi Notes: remove support for output_type = "latents', we can just skip decode/encode step in modular - vae_scale_factor = components.vae_scale_factor - block_state.latents = components.pachifier.unpack_latents( - block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor - ) + if block_state.latents.ndim == 4: + block_state.latents = block_state.latents.unsqueeze(dim=1) + elif block_state.latents.ndim != 5: + raise ValueError( + f"expect latents to be a 4D or 5D tensor but got: {block_state.latents.shape}. Please make sure the latents are unpacked before decode step." + ) block_state.latents = block_state.latents.to(components.vae.dtype) latents_mean = ( diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py index 55a7ae328f..dcce0cab5d 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py @@ -26,7 +26,12 @@ from .before_denoise import ( QwenImageSetTimestepsStep, QwenImageSetTimestepsWithStrengthStep, ) -from .decoders import QwenImageDecoderStep, QwenImageInpaintProcessImagesOutputStep, QwenImageProcessImagesOutputStep +from .decoders import ( + QwenImageAfterDenoiseStep, + QwenImageDecoderStep, + QwenImageInpaintProcessImagesOutputStep, + QwenImageProcessImagesOutputStep, +) from .denoise import ( QwenImageControlNetDenoiseStep, QwenImageDenoiseStep, @@ -92,6 +97,7 @@ TEXT2IMAGE_BLOCKS = InsertableDict( ("set_timesteps", QwenImageSetTimestepsStep()), ("prepare_rope_inputs", QwenImageRoPEInputsStep()), ("denoise", QwenImageDenoiseStep()), + ("after_denoise", QwenImageAfterDenoiseStep()), ("decode", QwenImageDecodeStep()), ] ) @@ -205,6 +211,7 @@ INPAINT_BLOCKS = InsertableDict( ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()), ("prepare_rope_inputs", QwenImageRoPEInputsStep()), ("denoise", QwenImageInpaintDenoiseStep()), + ("after_denoise", QwenImageAfterDenoiseStep()), ("decode", QwenImageInpaintDecodeStep()), ] ) @@ -264,6 +271,7 @@ IMAGE2IMAGE_BLOCKS = InsertableDict( ("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()), ("prepare_rope_inputs", QwenImageRoPEInputsStep()), ("denoise", QwenImageDenoiseStep()), + ("after_denoise", QwenImageAfterDenoiseStep()), ("decode", QwenImageDecodeStep()), ] ) @@ -529,8 +537,16 @@ class QwenImageCoreDenoiseStep(SequentialPipelineBlocks): QwenImageAutoBeforeDenoiseStep, QwenImageOptionalControlNetBeforeDenoiseStep, QwenImageAutoDenoiseStep, + QwenImageAfterDenoiseStep, + ] + block_names = [ + "input", + "controlnet_input", + "before_denoise", + "controlnet_before_denoise", + "denoise", + "after_denoise", ] - block_names = ["input", "controlnet_input", "before_denoise", "controlnet_before_denoise", "denoise"] @property def description(self): @@ -653,6 +669,7 @@ EDIT_BLOCKS = InsertableDict( ("set_timesteps", QwenImageSetTimestepsStep()), ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()), ("denoise", QwenImageEditDenoiseStep()), + ("after_denoise", QwenImageAfterDenoiseStep()), ("decode", QwenImageDecodeStep()), ] ) @@ -702,6 +719,7 @@ EDIT_INPAINT_BLOCKS = InsertableDict( ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()), ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()), ("denoise", QwenImageEditInpaintDenoiseStep()), + ("after_denoise", QwenImageAfterDenoiseStep()), ("decode", QwenImageInpaintDecodeStep()), ] ) @@ -841,8 +859,9 @@ class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks): QwenImageEditAutoInputStep, QwenImageEditAutoBeforeDenoiseStep, QwenImageEditAutoDenoiseStep, + QwenImageAfterDenoiseStep, ] - block_names = ["input", "before_denoise", "denoise"] + block_names = ["input", "before_denoise", "denoise", "after_denoise"] @property def description(self): @@ -954,6 +973,7 @@ EDIT_PLUS_BLOCKS = InsertableDict( ("set_timesteps", QwenImageSetTimestepsStep()), ("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()), ("denoise", QwenImageEditDenoiseStep()), + ("after_denoise", QwenImageAfterDenoiseStep()), ("decode", QwenImageDecodeStep()), ] ) @@ -1037,8 +1057,9 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks): QwenImageEditPlusAutoInputStep, QwenImageEditPlusAutoBeforeDenoiseStep, QwenImageEditAutoDenoiseStep, + QwenImageAfterDenoiseStep, ] - block_names = ["input", "before_denoise", "denoise"] + block_names = ["input", "before_denoise", "denoise", "after_denoise"] @property def description(self): diff --git a/src/diffusers/modular_pipelines/qwenimage/node_utils.py b/src/diffusers/modular_pipelines/qwenimage/node_utils.py deleted file mode 100644 index 3230ece68a..0000000000 --- a/src/diffusers/modular_pipelines/qwenimage/node_utils.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# mellon nodes -QwenImage_NODE_TYPES_PARAMS_MAP = { - "controlnet": { - "inputs": [ - "control_image", - "controlnet_conditioning_scale", - "control_guidance_start", - "control_guidance_end", - "height", - "width", - ], - "model_inputs": [ - "controlnet", - "vae", - ], - "outputs": [ - "controlnet_out", - ], - "block_names": ["controlnet_vae_encoder"], - }, - "denoise": { - "inputs": [ - "embeddings", - "width", - "height", - "seed", - "num_inference_steps", - "guidance_scale", - "image_latents", - "strength", - "controlnet", - ], - "model_inputs": [ - "unet", - "guider", - "scheduler", - ], - "outputs": [ - "latents", - "latents_preview", - ], - "block_names": ["denoise"], - }, - "vae_encoder": { - "inputs": [ - "image", - "width", - "height", - ], - "model_inputs": [ - "vae", - ], - "outputs": [ - "image_latents", - ], - }, - "text_encoder": { - "inputs": [ - "prompt", - "negative_prompt", - ], - "model_inputs": [ - "text_encoders", - ], - "outputs": [ - "embeddings", - ], - }, - "decoder": { - "inputs": [ - "latents", - ], - "model_inputs": [ - "vae", - ], - "outputs": [ - "images", - ], - }, -} diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/node_utils.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/node_utils.py deleted file mode 100644 index 3e788bf947..0000000000 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/node_utils.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -SDXL_NODE_TYPES_PARAMS_MAP = { - "controlnet": { - "inputs": [ - "control_image", - "controlnet_conditioning_scale", - "control_guidance_start", - "control_guidance_end", - "height", - "width", - ], - "model_inputs": [ - "controlnet", - ], - "outputs": [ - "controlnet_out", - ], - "block_names": [None], - }, - "denoise": { - "inputs": [ - "embeddings", - "width", - "height", - "seed", - "num_inference_steps", - "guidance_scale", - "image_latents", - "strength", - # custom adapters coming in as inputs - "controlnet", - # ip_adapter is optional and custom; include if available - "ip_adapter", - ], - "model_inputs": [ - "unet", - "guider", - "scheduler", - ], - "outputs": [ - "latents", - "latents_preview", - ], - "block_names": ["denoise"], - }, - "vae_encoder": { - "inputs": [ - "image", - "width", - "height", - ], - "model_inputs": [ - "vae", - ], - "outputs": [ - "image_latents", - ], - "block_names": ["vae_encoder"], - }, - "text_encoder": { - "inputs": [ - "prompt", - "negative_prompt", - ], - "model_inputs": [ - "text_encoders", - ], - "outputs": [ - "embeddings", - ], - "block_names": ["text_encoder"], - }, - "decoder": { - "inputs": [ - "latents", - ], - "model_inputs": [ - "vae", - ], - "outputs": [ - "images", - ], - "block_names": ["decode"], - }, -} diff --git a/src/diffusers/modular_pipelines/z_image/denoise.py b/src/diffusers/modular_pipelines/z_image/denoise.py index ec815f77ad..3d5a00a9df 100644 --- a/src/diffusers/modular_pipelines/z_image/denoise.py +++ b/src/diffusers/modular_pipelines/z_image/denoise.py @@ -129,6 +129,10 @@ class ZImageLoopDenoiser(ModularPipelineBlocks): type_hint=int, description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", ), + InputParam( + kwargs_type="denoiser_input_fields", + description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.", + ), ] guider_input_names = [] uncond_guider_input_names = [] diff --git a/src/diffusers/modular_pipelines/z_image/modular_blocks.py b/src/diffusers/modular_pipelines/z_image/modular_blocks.py index a7c520301a..a54baeccaf 100644 --- a/src/diffusers/modular_pipelines/z_image/modular_blocks.py +++ b/src/diffusers/modular_pipelines/z_image/modular_blocks.py @@ -119,7 +119,7 @@ class ZImageAutoDenoiseStep(AutoPipelineBlocks): class ZImageAutoVaeImageEncoderStep(AutoPipelineBlocks): block_classes = [ZImageVaeImageEncoderStep] - block_names = ["vae_image_encoder"] + block_names = ["vae_encoder"] block_trigger_inputs = ["image"] @property @@ -137,7 +137,7 @@ class ZImageAutoBlocks(SequentialPipelineBlocks): ZImageAutoDenoiseStep, ZImageVaeDecoderStep, ] - block_names = ["text_encoder", "vae_image_encoder", "denoise", "decode"] + block_names = ["text_encoder", "vae_encoder", "denoise", "decode"] @property def description(self) -> str: @@ -162,7 +162,7 @@ TEXT2IMAGE_BLOCKS = InsertableDict( IMAGE2IMAGE_BLOCKS = InsertableDict( [ ("text_encoder", ZImageTextEncoderStep), - ("vae_image_encoder", ZImageVaeImageEncoderStep), + ("vae_encoder", ZImageVaeImageEncoderStep), ("input", ZImageTextInputStep), ("additional_inputs", ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"])), ("prepare_latents", ZImagePrepareLatentsStep), @@ -178,7 +178,7 @@ IMAGE2IMAGE_BLOCKS = InsertableDict( AUTO_BLOCKS = InsertableDict( [ ("text_encoder", ZImageTextEncoderStep), - ("vae_image_encoder", ZImageAutoVaeImageEncoderStep), + ("vae_encoder", ZImageAutoVaeImageEncoderStep), ("denoise", ZImageAutoDenoiseStep), ("decode", ZImageVaeDecoderStep), ] From 262ce19bff6b19e38aed3519fc9eb2d90d24f87a Mon Sep 17 00:00:00 2001 From: MatrixTeam-AI Date: Sat, 20 Dec 2025 07:10:40 +0800 Subject: [PATCH 3/8] Feature: Add Mambo-G Guidance as Guider (#12862) * Feature: Add Mambo-G Guidance to Qwen-Image Pipeline * change to guider implementation * fix copied code residual * Update src/diffusers/guiders/magnitude_aware_guidance.py * Apply style fixes --------- Co-authored-by: Pscgylotti Co-authored-by: YiYi Xu Co-authored-by: github-actions[bot] --- src/diffusers/guiders/__init__.py | 1 + .../guiders/magnitude_aware_guidance.py | 159 ++++++++++++++++++ 2 files changed, 160 insertions(+) create mode 100644 src/diffusers/guiders/magnitude_aware_guidance.py diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index 4e53c373c4..58ad0c211b 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -25,6 +25,7 @@ if is_torch_available(): from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance from .frequency_decoupled_guidance import FrequencyDecoupledGuidance from .guider_utils import BaseGuidance + from .magnitude_aware_guidance import MagnitudeAwareGuidance from .perturbed_attention_guidance import PerturbedAttentionGuidance from .skip_layer_guidance import SkipLayerGuidance from .smoothed_energy_guidance import SmoothedEnergyGuidance diff --git a/src/diffusers/guiders/magnitude_aware_guidance.py b/src/diffusers/guiders/magnitude_aware_guidance.py new file mode 100644 index 0000000000..b81cf0d3a1 --- /dev/null +++ b/src/diffusers/guiders/magnitude_aware_guidance.py @@ -0,0 +1,159 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch + +from ..configuration_utils import register_to_config +from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg + + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +class MagnitudeAwareGuidance(BaseGuidance): + """ + Magnitude-Aware Mitigation for Boosted Guidance (MAMBO-G): https://huggingface.co/papers/2508.03442 + + Args: + guidance_scale (`float`, defaults to `10.0`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + alpha (`float`, defaults to `8.0`): + The alpha parameter for the magnitude-aware guidance. Higher values cause more aggressive supression of + guidance scale when the magnitude of the guidance update is large. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + @register_to_config + def __init__( + self, + guidance_scale: float = 10.0, + alpha: float = 8.0, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + enabled: bool = True, + ): + super().__init__(start, stop, enabled) + + self.guidance_scale = guidance_scale + self.alpha = alpha + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch(data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: + pred = None + + if not self._is_mambo_g_enabled(): + pred = pred_cond + else: + pred = mambo_guidance( + pred_cond, + pred_uncond, + self.guidance_scale, + self.alpha, + self.use_original_formulation, + ) + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond) + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_mambo_g_enabled(): + num_conditions += 1 + return num_conditions + + def _is_mambo_g_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + +def mambo_guidance( + pred_cond: torch.Tensor, + pred_uncond: torch.Tensor, + guidance_scale: float, + alpha: float = 8.0, + use_original_formulation: bool = False, +): + dim = list(range(1, len(pred_cond.shape))) + diff = pred_cond - pred_uncond + ratio = torch.norm(diff, dim=dim, keepdim=True) / torch.norm(pred_uncond, dim=dim, keepdim=True) + guidance_scale_final = ( + guidance_scale * torch.exp(-alpha * ratio) + if use_original_formulation + else 1.0 + (guidance_scale - 1.0) * torch.exp(-alpha * ratio) + ) + pred = pred_cond if use_original_formulation else pred_uncond + pred = pred + guidance_scale_final * diff + + return pred From 0c4f6c9cffb9df58d187f88e6434be5e03d3b8ac Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Tue, 23 Dec 2025 02:14:03 +0900 Subject: [PATCH 4/8] Add `OvisImagePipeline` in `AUTO_TEXT2IMAGE_PIPELINES_MAPPING` (#12876) --- src/diffusers/pipelines/auto_pipeline.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index db0268a2a7..4106a8fda7 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -73,6 +73,7 @@ from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline from .lumina import LuminaPipeline from .lumina2 import Lumina2Pipeline +from .ovis_image import OvisImagePipeline from .pag import ( HunyuanDiTPAGPipeline, PixArtSigmaPAGPipeline, @@ -164,6 +165,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( ("qwenimage", QwenImagePipeline), ("qwenimage-controlnet", QwenImageControlNetPipeline), ("z-image", ZImagePipeline), + ("ovis", OvisImagePipeline), ] ) From 973a077c6a4e7e7a7ea61a84bedd29ac24fb609a Mon Sep 17 00:00:00 2001 From: Miguel Martin Date: Mon, 22 Dec 2025 10:02:06 -0800 Subject: [PATCH 5/8] Cosmos Predict2.5 14b Conversion (#12863) 14b conversion --- scripts/convert_cosmos_to_diffusers.py | 60 ++++++++++++++++++- .../cosmos/pipeline_cosmos2_5_predict.py | 2 +- 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/scripts/convert_cosmos_to_diffusers.py b/scripts/convert_cosmos_to_diffusers.py index 6e70f8cc05..bc6014068e 100644 --- a/scripts/convert_cosmos_to_diffusers.py +++ b/scripts/convert_cosmos_to_diffusers.py @@ -29,13 +29,52 @@ hf download nvidia/Cosmos-Predict2.5-2B Convert checkpoint ```bash +# pre-trained transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt python scripts/convert_cosmos_to_diffusers.py \ --transformer_type Cosmos-2.5-Predict-Base-2B \ --transformer_ckpt_path $transformer_ckpt_path \ --vae_type wan2.1 \ - --output_path converted/cosmos-p2.5-base-2b \ + --output_path converted/2b/d20b7120-df3e-4911-919d-db6e08bad31c \ + --save_pipeline + +# post-trained +transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/post-trained/81edfebe-bd6a-4039-8c1d-737df1a790bf_ema_bf16.pt + +python scripts/convert_cosmos_to_diffusers.py \ + --transformer_type Cosmos-2.5-Predict-Base-2B \ + --transformer_ckpt_path $transformer_ckpt_path \ + --vae_type wan2.1 \ + --output_path converted/2b/81edfebe-bd6a-4039-8c1d-737df1a790bf \ + --save_pipeline +``` + +## 14B + +```bash +hf download nvidia/Cosmos-Predict2.5-14B +``` + +```bash +# pre-trained +transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/pre-trained/54937b8c-29de-4f04-862c-e67b04ec41e8_ema_bf16.pt + +python scripts/convert_cosmos_to_diffusers.py \ + --transformer_type Cosmos-2.5-Predict-Base-14B \ + --transformer_ckpt_path $transformer_ckpt_path \ + --vae_type wan2.1 \ + --output_path converted/14b/54937b8c-29de-4f04-862c-e67b04ec41e8/ \ + --save_pipeline + +# post-trained +transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/post-trained/e21d2a49-4747-44c8-ba44-9f6f9243715f_ema_bf16.pt + +python scripts/convert_cosmos_to_diffusers.py \ + --transformer_type Cosmos-2.5-Predict-Base-14B \ + --transformer_ckpt_path $transformer_ckpt_path \ + --vae_type wan2.1 \ + --output_path converted/14b/e21d2a49-4747-44c8-ba44-9f6f9243715f/ \ --save_pipeline ``` @@ -298,6 +337,25 @@ TRANSFORMER_CONFIGS = { "crossattn_proj_in_channels": 100352, "encoder_hidden_states_channels": 1024, }, + "Cosmos-2.5-Predict-Base-14B": { + "in_channels": 16 + 1, + "out_channels": 16, + "num_attention_heads": 40, + "attention_head_dim": 128, + "num_layers": 36, + "mlp_ratio": 4.0, + "text_embed_dim": 1024, + "adaln_lora_dim": 256, + "max_size": (128, 240, 240), + "patch_size": (1, 2, 2), + "rope_scale": (1.0, 3.0, 3.0), + "concat_padding_mask": True, + # NOTE: source config has pos_emb_learnable: 'True' - but params are missing + "extra_pos_embed_type": None, + "use_crossattn_projection": True, + "crossattn_proj_in_channels": 100352, + "encoder_hidden_states_channels": 1024, + }, } VAE_KEYS_RENAME_DICT = { diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py index 6564b59373..372684e0b5 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py @@ -133,7 +133,7 @@ EXAMPLE_DOC_STRING = """ ... num_frames=93, ... generator=torch.Generator().manual_seed(1), ... ).frames[0] - >>> # export_to_video(video, "image2world.mp4", fps=16) + >>> export_to_video(video, "image2world.mp4", fps=16) >>> # Video2World: condition on an input clip and predict a 93-frame world video. >>> prompt = ( From 52766e6a6939ac6e74375bde5e19c5e0b90d24c1 Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Wed, 24 Dec 2025 01:57:41 +0900 Subject: [PATCH 6/8] Use `T5Tokenizer` instead of `MT5Tokenizer` (removed in Transformers v5.0+) (#12877) Use `T5Tokenizer` instead of `MT5Tokenizer` Given that the `MT5Tokenizer` in `transformers` is just a "re-export" of `T5Tokenizer` as per https://github.com/huggingface/transformers/blob/v4.57.3/src/transformers/models/mt5/tokenization_mt5.py )on latest available stable Transformers i.e., v4.57.3), this commit updates the imports to point to `T5Tokenizer` instead, so that those still work with Transformers v5.0.0rc0 onwards. --- .../community/pipeline_hunyuandit_differential_img2img.py | 6 +++--- .../controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py | 6 +++--- src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py | 6 +++--- src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py | 6 +++--- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/community/pipeline_hunyuandit_differential_img2img.py b/examples/community/pipeline_hunyuandit_differential_img2img.py index fb7a4cb5e4..bc6841525b 100644 --- a/examples/community/pipeline_hunyuandit_differential_img2img.py +++ b/examples/community/pipeline_hunyuandit_differential_img2img.py @@ -21,8 +21,8 @@ from transformers import ( BertModel, BertTokenizer, CLIPImageProcessor, - MT5Tokenizer, T5EncoderModel, + T5Tokenizer, ) from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback @@ -260,7 +260,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline): The HunyuanDiT model designed by Tencent Hunyuan. text_encoder_2 (`T5EncoderModel`): The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. - tokenizer_2 (`MT5Tokenizer`): + tokenizer_2 (`T5Tokenizer`): The tokenizer for the mT5 embedder. scheduler ([`DDPMScheduler`]): A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. @@ -295,7 +295,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline): feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, text_encoder_2=T5EncoderModel, - tokenizer_2=MT5Tokenizer, + tokenizer_2=T5Tokenizer, ): super().__init__() diff --git a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py index 2b5684de95..29a7d61476 100644 --- a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +++ b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py @@ -17,7 +17,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch -from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel +from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput @@ -185,7 +185,7 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline): The HunyuanDiT model designed by Tencent Hunyuan. text_encoder_2 (`T5EncoderModel`): The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. - tokenizer_2 (`MT5Tokenizer`): + tokenizer_2 (`T5Tokenizer`): The tokenizer for the mT5 embedder. scheduler ([`DDPMScheduler`]): A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. @@ -229,7 +229,7 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline): HunyuanDiT2DMultiControlNetModel, ], text_encoder_2: Optional[T5EncoderModel] = None, - tokenizer_2: Optional[MT5Tokenizer] = None, + tokenizer_2: Optional[T5Tokenizer] = None, requires_safety_checker: bool = True, ): super().__init__() diff --git a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py index e2f935aaf4..052c7b4739 100644 --- a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +++ b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py @@ -17,7 +17,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch -from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel +from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput @@ -169,7 +169,7 @@ class HunyuanDiTPipeline(DiffusionPipeline): The HunyuanDiT model designed by Tencent Hunyuan. text_encoder_2 (`T5EncoderModel`): The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. - tokenizer_2 (`MT5Tokenizer`): + tokenizer_2 (`T5Tokenizer`): The tokenizer for the mT5 embedder. scheduler ([`DDPMScheduler`]): A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. @@ -204,7 +204,7 @@ class HunyuanDiTPipeline(DiffusionPipeline): feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, text_encoder_2: Optional[T5EncoderModel] = None, - tokenizer_2: Optional[MT5Tokenizer] = None, + tokenizer_2: Optional[T5Tokenizer] = None, ): super().__init__() diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py index d156eac8f3..6704924b25 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py @@ -17,7 +17,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch -from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel +from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput @@ -173,7 +173,7 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin): The HunyuanDiT model designed by Tencent Hunyuan. text_encoder_2 (`T5EncoderModel`): The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. - tokenizer_2 (`MT5Tokenizer`): + tokenizer_2 (`T5Tokenizer`): The tokenizer for the mT5 embedder. scheduler ([`DDPMScheduler`]): A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. @@ -208,7 +208,7 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin): feature_extractor: Optional[CLIPImageProcessor] = None, requires_safety_checker: bool = True, text_encoder_2: Optional[T5EncoderModel] = None, - tokenizer_2: Optional[MT5Tokenizer] = None, + tokenizer_2: Optional[T5Tokenizer] = None, pag_applied_layers: Union[str, List[str]] = "blocks.1", # "blocks.16.attn1", "blocks.16", "16", 16 ): super().__init__() From f6b6a7181eb44f0120b29cd897c129275f366c2a Mon Sep 17 00:00:00 2001 From: RuoyiDu <61931443+RuoyiDu@users.noreply.github.com> Date: Wed, 24 Dec 2025 17:45:35 +0800 Subject: [PATCH 7/8] Add z-image-omni-base implementation (#12857) * Add z-image-omni-base implementation * Merged into one transformer for Z-Image. * Fix bugs for controlnet after merging the main branch new feature. * Fix for auto_pipeline, Add Styling. * Refactor noise handling and modulation - Add select_per_token function for per-token value selection - Separate adaptive modulation logic - Cleanify t_noisy/clean variable naming - Move image_noise_mask handler from forward to pipeline * Styling & Formatting. * Rewrite code with more non-forward func & clean forward. 1.Change to one forward with shorter code with omni code (None). 2.Split out non-forward funcs: _build_unified_sequence, _prepare_sequence, patchify, pad. * Styling & Formatting. * Manual check fix-copies in controlnet, Add select_per_token, _patchify_image, _pad_with_ids; Styling. * Add Import in pipeline __init__.py. --------- Co-authored-by: Jerry Qilong Wu Co-authored-by: YiYi Xu --- src/diffusers/__init__.py | 2 + .../models/controlnets/controlnet_z_image.py | 213 ++--- .../transformers/transformer_z_image.py | 806 +++++++++++++----- src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/auto_pipeline.py | 11 +- src/diffusers/pipelines/z_image/__init__.py | 3 +- .../z_image/pipeline_z_image_omni.py | 742 ++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 15 + 8 files changed, 1483 insertions(+), 311 deletions(-) create mode 100644 src/diffusers/pipelines/z_image/pipeline_z_image_omni.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6aac3feffd..aa11a741af 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -675,6 +675,7 @@ else: "ZImageControlNetInpaintPipeline", "ZImageControlNetPipeline", "ZImageImg2ImgPipeline", + "ZImageOmniPipeline", "ZImagePipeline", ] ) @@ -1386,6 +1387,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ZImageControlNetInpaintPipeline, ZImageControlNetPipeline, ZImageImg2ImgPipeline, + ZImageOmniPipeline, ZImagePipeline, ) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index 54e398ea13..3f79ec9254 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import List, Literal, Optional +from typing import List, Literal, Optional, Tuple import torch import torch.nn as nn @@ -170,6 +170,21 @@ class FeedForward(nn.Module): return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) +# Copied from diffusers.models.transformers.transformer_z_image.select_per_token +def select_per_token( + value_noisy: torch.Tensor, + value_clean: torch.Tensor, + noise_mask: torch.Tensor, + seq_len: int, +) -> torch.Tensor: + noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1) + return torch.where( + noise_mask_expanded == 1, + value_noisy.unsqueeze(1).expand(-1, seq_len, -1), + value_clean.unsqueeze(1).expand(-1, seq_len, -1), + ) + + @maybe_allow_in_graph # Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformerBlock class ZImageTransformerBlock(nn.Module): @@ -220,12 +235,37 @@ class ZImageTransformerBlock(nn.Module): attn_mask: torch.Tensor, freqs_cis: torch.Tensor, adaln_input: Optional[torch.Tensor] = None, + noise_mask: Optional[torch.Tensor] = None, + adaln_noisy: Optional[torch.Tensor] = None, + adaln_clean: Optional[torch.Tensor] = None, ): if self.modulation: - assert adaln_input is not None - scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) - gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() - scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + seq_len = x.shape[1] + + if noise_mask is not None: + # Per-token modulation: different modulation for noisy/clean tokens + mod_noisy = self.adaLN_modulation(adaln_noisy) + mod_clean = self.adaLN_modulation(adaln_clean) + + scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1) + scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1) + + gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh() + gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh() + + scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy + scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean + + scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len) + scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len) + gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len) + gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len) + else: + # Global modulation: same modulation for all tokens (avoid double select) + mod = self.adaLN_modulation(adaln_input) + scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2) + gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp # Attention block attn_out = self.attention( @@ -493,112 +533,93 @@ class ZImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi def create_coordinate_grid(size, start=None, device=None): if start is None: start = (0 for _ in size) - axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] grids = torch.meshgrid(axes, indexing="ij") return torch.stack(grids, dim=-1) + # Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel._patchify_image + def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int): + """Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim).""" + pH, pW, pF = patch_size, patch_size, f_patch_size + C, F, H, W = image.size() + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + return image, (F, H, W), (F_tokens, H_tokens, W_tokens) + + # Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel._pad_with_ids + def _pad_with_ids( + self, + feat: torch.Tensor, + pos_grid_size: Tuple, + pos_start: Tuple, + device: torch.device, + noise_mask_val: Optional[int] = None, + ): + """Pad feature to SEQ_MULTI_OF, create position IDs and pad mask.""" + ori_len = len(feat) + pad_len = (-ori_len) % SEQ_MULTI_OF + total_len = ori_len + pad_len + + # Pos IDs + ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2) + if pad_len > 0: + pad_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(pad_len, 1) + ) + pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0) + padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0) + pad_mask = torch.cat( + [ + torch.zeros(ori_len, dtype=torch.bool, device=device), + torch.ones(pad_len, dtype=torch.bool, device=device), + ] + ) + else: + pos_ids = ori_pos_ids + padded_feat = feat + pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device) + + noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level + return padded_feat, pos_ids, pad_mask, total_len, noise_mask + # Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel.patchify_and_embed def patchify_and_embed( - self, - all_image: List[torch.Tensor], - all_cap_feats: List[torch.Tensor], - patch_size: int, - f_patch_size: int, + self, all_image: List[torch.Tensor], all_cap_feats: List[torch.Tensor], patch_size: int, f_patch_size: int ): - pH = pW = patch_size - pF = f_patch_size + """Patchify for basic mode: single image per batch item.""" device = all_image[0].device + all_img_out, all_img_size, all_img_pos_ids, all_img_pad_mask = [], [], [], [] + all_cap_out, all_cap_pos_ids, all_cap_pad_mask = [], [], [] - all_image_out = [] - all_image_size = [] - all_image_pos_ids = [] - all_image_pad_mask = [] - all_cap_pos_ids = [] - all_cap_pad_mask = [] - all_cap_feats_out = [] - - for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)): - ### Process Caption - cap_ori_len = len(cap_feat) - cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF - # padded position ids - cap_padded_pos_ids = self.create_coordinate_grid( - size=(cap_ori_len + cap_padding_len, 1, 1), - start=(1, 0, 0), - device=device, - ).flatten(0, 2) - all_cap_pos_ids.append(cap_padded_pos_ids) - # pad mask - cap_pad_mask = torch.cat( - [ - torch.zeros((cap_ori_len,), dtype=torch.bool, device=device), - torch.ones((cap_padding_len,), dtype=torch.bool, device=device), - ], - dim=0, - ) - all_cap_pad_mask.append( - cap_pad_mask if cap_padding_len > 0 else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device) + for image, cap_feat in zip(all_image, all_cap_feats): + # Caption + cap_out, cap_pos_ids, cap_pad_mask, cap_len, _ = self._pad_with_ids( + cap_feat, (len(cap_feat) + (-len(cap_feat)) % SEQ_MULTI_OF, 1, 1), (1, 0, 0), device ) + all_cap_out.append(cap_out) + all_cap_pos_ids.append(cap_pos_ids) + all_cap_pad_mask.append(cap_pad_mask) - # padded feature - cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0) - all_cap_feats_out.append(cap_padded_feat) - - ### Process Image - C, F, H, W = image.size() - all_image_size.append((F, H, W)) - F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW - - image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) - # "c f pf h ph w pw -> (f h w) (pf ph pw c)" - image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) - - image_ori_len = len(image) - image_padding_len = (-image_ori_len) % SEQ_MULTI_OF - - image_ori_pos_ids = self.create_coordinate_grid( - size=(F_tokens, H_tokens, W_tokens), - start=(cap_ori_len + cap_padding_len + 1, 0, 0), - device=device, - ).flatten(0, 2) - image_padded_pos_ids = torch.cat( - [ - image_ori_pos_ids, - self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) - .flatten(0, 2) - .repeat(image_padding_len, 1), - ], - dim=0, + # Image + img_patches, size, (F_t, H_t, W_t) = self._patchify_image(image, patch_size, f_patch_size) + img_out, img_pos_ids, img_pad_mask, _, _ = self._pad_with_ids( + img_patches, (F_t, H_t, W_t), (cap_len + 1, 0, 0), device ) - all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids) - # pad mask - image_pad_mask = torch.cat( - [ - torch.zeros((image_ori_len,), dtype=torch.bool, device=device), - torch.ones((image_padding_len,), dtype=torch.bool, device=device), - ], - dim=0, - ) - all_image_pad_mask.append( - image_pad_mask - if image_padding_len > 0 - else torch.zeros((image_ori_len,), dtype=torch.bool, device=device) - ) - # padded feature - image_padded_feat = torch.cat( - [image, image[-1:].repeat(image_padding_len, 1)], - dim=0, - ) - all_image_out.append(image_padded_feat if image_padding_len > 0 else image) + all_img_out.append(img_out) + all_img_size.append(size) + all_img_pos_ids.append(img_pos_ids) + all_img_pad_mask.append(img_pad_mask) return ( - all_image_out, - all_cap_feats_out, - all_image_size, - all_image_pos_ids, + all_img_out, + all_cap_out, + all_img_size, + all_img_pos_ids, all_cap_pos_ids, - all_image_pad_mask, + all_img_pad_mask, all_cap_pad_mask, ) diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 17197db3a4..5983c34ab6 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -32,6 +32,7 @@ from ..modeling_outputs import Transformer2DModelOutput ADALN_EMBED_DIM = 256 SEQ_MULTI_OF = 32 +X_PAD_DIM = 64 class TimestepEmbedder(nn.Module): @@ -152,6 +153,20 @@ class ZSingleStreamAttnProcessor: return output +def select_per_token( + value_noisy: torch.Tensor, + value_clean: torch.Tensor, + noise_mask: torch.Tensor, + seq_len: int, +) -> torch.Tensor: + noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1) + return torch.where( + noise_mask_expanded == 1, + value_noisy.unsqueeze(1).expand(-1, seq_len, -1), + value_clean.unsqueeze(1).expand(-1, seq_len, -1), + ) + + class FeedForward(nn.Module): def __init__(self, dim: int, hidden_dim: int): super().__init__() @@ -215,12 +230,37 @@ class ZImageTransformerBlock(nn.Module): attn_mask: torch.Tensor, freqs_cis: torch.Tensor, adaln_input: Optional[torch.Tensor] = None, + noise_mask: Optional[torch.Tensor] = None, + adaln_noisy: Optional[torch.Tensor] = None, + adaln_clean: Optional[torch.Tensor] = None, ): if self.modulation: - assert adaln_input is not None - scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) - gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() - scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + seq_len = x.shape[1] + + if noise_mask is not None: + # Per-token modulation: different modulation for noisy/clean tokens + mod_noisy = self.adaLN_modulation(adaln_noisy) + mod_clean = self.adaLN_modulation(adaln_clean) + + scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1) + scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1) + + gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh() + gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh() + + scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy + scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean + + scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len) + scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len) + gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len) + gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len) + else: + # Global modulation: same modulation for all tokens (avoid double select) + mod = self.adaLN_modulation(adaln_input) + scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2) + gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp # Attention block attn_out = self.attention( @@ -252,9 +292,21 @@ class FinalLayer(nn.Module): nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True), ) - def forward(self, x, c): - scale = 1.0 + self.adaLN_modulation(c) - x = self.norm_final(x) * scale.unsqueeze(1) + def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None): + seq_len = x.shape[1] + + if noise_mask is not None: + # Per-token modulation + scale_noisy = 1.0 + self.adaLN_modulation(c_noisy) + scale_clean = 1.0 + self.adaLN_modulation(c_clean) + scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len) + else: + # Original global modulation + assert c is not None, "Either c or (c_noisy, c_clean) must be provided" + scale = 1.0 + self.adaLN_modulation(c) + scale = scale.unsqueeze(1) + + x = self.norm_final(x) * scale x = self.linear(x) return x @@ -325,6 +377,7 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr norm_eps=1e-5, qk_norm=True, cap_feat_dim=2560, + siglip_feat_dim=None, # Optional: set to enable SigLIP support for Omni rope_theta=256.0, t_scale=1000.0, axes_dims=[32, 48, 48], @@ -386,6 +439,31 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True)) + # Optional SigLIP components (for Omni variant) + if siglip_feat_dim is not None: + self.siglip_embedder = nn.Sequential( + RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True) + ) + self.siglip_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 2000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.siglip_pad_token = nn.Parameter(torch.empty((1, dim))) + else: + self.siglip_embedder = None + self.siglip_refiner = None + self.siglip_pad_token = None + self.x_pad_token = nn.Parameter(torch.empty((1, dim))) self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) @@ -402,259 +480,561 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) - def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]: + def unpatchify( + self, + x: List[torch.Tensor], + size: List[Tuple], + patch_size, + f_patch_size, + x_pos_offsets: Optional[List[Tuple[int, int]]] = None, + ) -> List[torch.Tensor]: pH = pW = patch_size pF = f_patch_size bsz = len(x) assert len(size) == bsz - for i in range(bsz): - F, H, W = size[i] - ori_len = (F // pF) * (H // pH) * (W // pW) - # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)" - x[i] = ( - x[i][:ori_len] - .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) - .permute(6, 0, 3, 1, 4, 2, 5) - .reshape(self.out_channels, F, H, W) - ) - return x + + if x_pos_offsets is not None: + # Omni: extract target image from unified sequence (cond_images + target) + result = [] + for i in range(bsz): + unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]] + cu_len = 0 + x_item = None + for j in range(len(size[i])): + if size[i][j] is None: + ori_len = 0 + pad_len = SEQ_MULTI_OF + cu_len += pad_len + ori_len + else: + F, H, W = size[i][j] + ori_len = (F // pF) * (H // pH) * (W // pW) + pad_len = (-ori_len) % SEQ_MULTI_OF + x_item = ( + unified_x[cu_len : cu_len + ori_len] + .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + cu_len += ori_len + pad_len + result.append(x_item) # Return only the last (target) image + return result + else: + # Original mode: simple unpatchify + for i in range(bsz): + F, H, W = size[i] + ori_len = (F // pF) * (H // pH) * (W // pW) + # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)" + x[i] = ( + x[i][:ori_len] + .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + return x @staticmethod def create_coordinate_grid(size, start=None, device=None): if start is None: start = (0 for _ in size) - axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] grids = torch.meshgrid(axes, indexing="ij") return torch.stack(grids, dim=-1) - def patchify_and_embed( + def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int): + """Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim).""" + pH, pW, pF = patch_size, patch_size, f_patch_size + C, F, H, W = image.size() + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + return image, (F, H, W), (F_tokens, H_tokens, W_tokens) + + def _pad_with_ids( self, - all_image: List[torch.Tensor], - all_cap_feats: List[torch.Tensor], - patch_size: int, - f_patch_size: int, + feat: torch.Tensor, + pos_grid_size: Tuple, + pos_start: Tuple, + device: torch.device, + noise_mask_val: Optional[int] = None, ): - pH = pW = patch_size - pF = f_patch_size + """Pad feature to SEQ_MULTI_OF, create position IDs and pad mask.""" + ori_len = len(feat) + pad_len = (-ori_len) % SEQ_MULTI_OF + total_len = ori_len + pad_len + + # Pos IDs + ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2) + if pad_len > 0: + pad_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(pad_len, 1) + ) + pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0) + padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0) + pad_mask = torch.cat( + [ + torch.zeros(ori_len, dtype=torch.bool, device=device), + torch.ones(pad_len, dtype=torch.bool, device=device), + ] + ) + else: + pos_ids = ori_pos_ids + padded_feat = feat + pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device) + + noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level + return padded_feat, pos_ids, pad_mask, total_len, noise_mask + + def patchify_and_embed( + self, all_image: List[torch.Tensor], all_cap_feats: List[torch.Tensor], patch_size: int, f_patch_size: int + ): + """Patchify for basic mode: single image per batch item.""" device = all_image[0].device + all_img_out, all_img_size, all_img_pos_ids, all_img_pad_mask = [], [], [], [] + all_cap_out, all_cap_pos_ids, all_cap_pad_mask = [], [], [] - all_image_out = [] - all_image_size = [] - all_image_pos_ids = [] - all_image_pad_mask = [] - all_cap_pos_ids = [] - all_cap_pad_mask = [] - all_cap_feats_out = [] - - for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)): - ### Process Caption - cap_ori_len = len(cap_feat) - cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF - # padded position ids - cap_padded_pos_ids = self.create_coordinate_grid( - size=(cap_ori_len + cap_padding_len, 1, 1), - start=(1, 0, 0), - device=device, - ).flatten(0, 2) - all_cap_pos_ids.append(cap_padded_pos_ids) - # pad mask - cap_pad_mask = torch.cat( - [ - torch.zeros((cap_ori_len,), dtype=torch.bool, device=device), - torch.ones((cap_padding_len,), dtype=torch.bool, device=device), - ], - dim=0, - ) - all_cap_pad_mask.append( - cap_pad_mask if cap_padding_len > 0 else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device) + for image, cap_feat in zip(all_image, all_cap_feats): + # Caption + cap_out, cap_pos_ids, cap_pad_mask, cap_len, _ = self._pad_with_ids( + cap_feat, (len(cap_feat) + (-len(cap_feat)) % SEQ_MULTI_OF, 1, 1), (1, 0, 0), device ) + all_cap_out.append(cap_out) + all_cap_pos_ids.append(cap_pos_ids) + all_cap_pad_mask.append(cap_pad_mask) - # padded feature - cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0) - all_cap_feats_out.append(cap_padded_feat) - - ### Process Image - C, F, H, W = image.size() - all_image_size.append((F, H, W)) - F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW - - image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) - # "c f pf h ph w pw -> (f h w) (pf ph pw c)" - image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) - - image_ori_len = len(image) - image_padding_len = (-image_ori_len) % SEQ_MULTI_OF - - image_ori_pos_ids = self.create_coordinate_grid( - size=(F_tokens, H_tokens, W_tokens), - start=(cap_ori_len + cap_padding_len + 1, 0, 0), - device=device, - ).flatten(0, 2) - image_padded_pos_ids = torch.cat( - [ - image_ori_pos_ids, - self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) - .flatten(0, 2) - .repeat(image_padding_len, 1), - ], - dim=0, + # Image + img_patches, size, (F_t, H_t, W_t) = self._patchify_image(image, patch_size, f_patch_size) + img_out, img_pos_ids, img_pad_mask, _, _ = self._pad_with_ids( + img_patches, (F_t, H_t, W_t), (cap_len + 1, 0, 0), device ) - all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids) - # pad mask - image_pad_mask = torch.cat( - [ - torch.zeros((image_ori_len,), dtype=torch.bool, device=device), - torch.ones((image_padding_len,), dtype=torch.bool, device=device), - ], - dim=0, - ) - all_image_pad_mask.append( - image_pad_mask - if image_padding_len > 0 - else torch.zeros((image_ori_len,), dtype=torch.bool, device=device) - ) - # padded feature - image_padded_feat = torch.cat( - [image, image[-1:].repeat(image_padding_len, 1)], - dim=0, - ) - all_image_out.append(image_padded_feat if image_padding_len > 0 else image) + all_img_out.append(img_out) + all_img_size.append(size) + all_img_pos_ids.append(img_pos_ids) + all_img_pad_mask.append(img_pad_mask) return ( - all_image_out, - all_cap_feats_out, - all_image_size, - all_image_pos_ids, + all_img_out, + all_cap_out, + all_img_size, + all_img_pos_ids, all_cap_pos_ids, - all_image_pad_mask, + all_img_pad_mask, all_cap_pad_mask, ) - def forward( + def patchify_and_embed_omni( self, - x: List[torch.Tensor], - t, - cap_feats: List[torch.Tensor], - controlnet_block_samples: Optional[Dict[int, torch.Tensor]] = None, - patch_size=2, - f_patch_size=1, - return_dict: bool = True, + all_x: List[List[torch.Tensor]], + all_cap_feats: List[List[torch.Tensor]], + all_siglip_feats: List[List[torch.Tensor]], + patch_size: int, + f_patch_size: int, + images_noise_mask: List[List[int]], ): - assert patch_size in self.all_patch_size - assert f_patch_size in self.all_f_patch_size + """Patchify for omni mode: multiple images per batch item with noise masks.""" + bsz = len(all_x) + device = all_x[0][-1].device + dtype = all_x[0][-1].dtype - bsz = len(x) - device = x[0].device - t = t * self.t_scale - t = self.t_embedder(t) + all_x_out, all_x_size, all_x_pos_ids, all_x_pad_mask, all_x_len, all_x_noise_mask = [], [], [], [], [], [] + all_cap_out, all_cap_pos_ids, all_cap_pad_mask, all_cap_len, all_cap_noise_mask = [], [], [], [], [] + all_sig_out, all_sig_pos_ids, all_sig_pad_mask, all_sig_len, all_sig_noise_mask = [], [], [], [], [] - ( - x, - cap_feats, - x_size, - x_pos_ids, - cap_pos_ids, - x_inner_pad_mask, - cap_inner_pad_mask, - ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + for i in range(bsz): + num_images = len(all_x[i]) + cap_feats_list, cap_pos_list, cap_mask_list, cap_lens, cap_noise = [], [], [], [], [] + cap_end_pos = [] + cap_cu_len = 1 - # x embed & refine - x_item_seqlens = [len(_) for _ in x] - assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) - x_max_item_seqlen = max(x_item_seqlens) + # Process captions + for j, cap_item in enumerate(all_cap_feats[i]): + noise_val = images_noise_mask[i][j] if j < len(images_noise_mask[i]) else 1 + cap_out, cap_pos, cap_mask, cap_len, cap_nm = self._pad_with_ids( + cap_item, + (len(cap_item) + (-len(cap_item)) % SEQ_MULTI_OF, 1, 1), + (cap_cu_len, 0, 0), + device, + noise_val, + ) + cap_feats_list.append(cap_out) + cap_pos_list.append(cap_pos) + cap_mask_list.append(cap_mask) + cap_lens.append(cap_len) + cap_noise.extend(cap_nm) + cap_cu_len += len(cap_item) + cap_end_pos.append(cap_cu_len) + cap_cu_len += 2 # for image vae and siglip tokens - x = torch.cat(x, dim=0) - x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) + all_cap_out.append(torch.cat(cap_feats_list, dim=0)) + all_cap_pos_ids.append(torch.cat(cap_pos_list, dim=0)) + all_cap_pad_mask.append(torch.cat(cap_mask_list, dim=0)) + all_cap_len.append(cap_lens) + all_cap_noise_mask.append(cap_noise) - # Match t_embedder output dtype to x for layerwise casting compatibility - adaln_input = t.type_as(x) - x[torch.cat(x_inner_pad_mask)] = self.x_pad_token - x = list(x.split(x_item_seqlens, dim=0)) - x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0)) + # Process images + x_feats_list, x_pos_list, x_mask_list, x_lens, x_size, x_noise = [], [], [], [], [], [] + for j, x_item in enumerate(all_x[i]): + noise_val = images_noise_mask[i][j] + if x_item is not None: + x_patches, size, (F_t, H_t, W_t) = self._patchify_image(x_item, patch_size, f_patch_size) + x_out, x_pos, x_mask, x_len, x_nm = self._pad_with_ids( + x_patches, (F_t, H_t, W_t), (cap_end_pos[j], 0, 0), device, noise_val + ) + x_size.append(size) + else: + x_len = SEQ_MULTI_OF + x_out = torch.zeros((x_len, X_PAD_DIM), dtype=dtype, device=device) + x_pos = self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(x_len, 1) + x_mask = torch.ones(x_len, dtype=torch.bool, device=device) + x_nm = [noise_val] * x_len + x_size.append(None) + x_feats_list.append(x_out) + x_pos_list.append(x_pos) + x_mask_list.append(x_mask) + x_lens.append(x_len) + x_noise.extend(x_nm) - x = pad_sequence(x, batch_first=True, padding_value=0.0) - x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) - # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors - x_freqs_cis = x_freqs_cis[:, : x.shape[1]] + all_x_out.append(torch.cat(x_feats_list, dim=0)) + all_x_pos_ids.append(torch.cat(x_pos_list, dim=0)) + all_x_pad_mask.append(torch.cat(x_mask_list, dim=0)) + all_x_size.append(x_size) + all_x_len.append(x_lens) + all_x_noise_mask.append(x_noise) - x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(x_item_seqlens): - x_attn_mask[i, :seq_len] = 1 + # Process siglip + if all_siglip_feats[i] is None: + all_sig_len.append([0] * num_images) + all_sig_out.append(None) + else: + sig_feats_list, sig_pos_list, sig_mask_list, sig_lens, sig_noise = [], [], [], [], [] + for j, sig_item in enumerate(all_siglip_feats[i]): + noise_val = images_noise_mask[i][j] + if sig_item is not None: + sig_H, sig_W, sig_C = sig_item.size() + sig_flat = sig_item.permute(2, 0, 1).reshape(sig_H * sig_W, sig_C) + sig_out, sig_pos, sig_mask, sig_len, sig_nm = self._pad_with_ids( + sig_flat, (1, sig_H, sig_W), (cap_end_pos[j] + 1, 0, 0), device, noise_val + ) + # Scale position IDs to match x resolution + if x_size[j] is not None: + sig_pos = sig_pos.float() + sig_pos[..., 1] = sig_pos[..., 1] / max(sig_H - 1, 1) * (x_size[j][1] - 1) + sig_pos[..., 2] = sig_pos[..., 2] / max(sig_W - 1, 1) * (x_size[j][2] - 1) + sig_pos = sig_pos.to(torch.int32) + else: + sig_len = SEQ_MULTI_OF + sig_out = torch.zeros((sig_len, self.config.siglip_feat_dim), dtype=dtype, device=device) + sig_pos = ( + self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(sig_len, 1) + ) + sig_mask = torch.ones(sig_len, dtype=torch.bool, device=device) + sig_nm = [noise_val] * sig_len + sig_feats_list.append(sig_out) + sig_pos_list.append(sig_pos) + sig_mask_list.append(sig_mask) + sig_lens.append(sig_len) + sig_noise.extend(sig_nm) - if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.noise_refiner: - x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input) - else: - for layer in self.noise_refiner: - x = layer(x, x_attn_mask, x_freqs_cis, adaln_input) + all_sig_out.append(torch.cat(sig_feats_list, dim=0)) + all_sig_pos_ids.append(torch.cat(sig_pos_list, dim=0)) + all_sig_pad_mask.append(torch.cat(sig_mask_list, dim=0)) + all_sig_len.append(sig_lens) + all_sig_noise_mask.append(sig_noise) - # cap embed & refine - cap_item_seqlens = [len(_) for _ in cap_feats] - cap_max_item_seqlen = max(cap_item_seqlens) + # Compute x position offsets + all_x_pos_offsets = [(sum(all_cap_len[i]), sum(all_cap_len[i]) + sum(all_x_len[i])) for i in range(bsz)] - cap_feats = torch.cat(cap_feats, dim=0) - cap_feats = self.cap_embedder(cap_feats) - cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token - cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) - cap_freqs_cis = list( - self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0) + return ( + all_x_out, + all_cap_out, + all_sig_out, + all_x_size, + all_x_pos_ids, + all_cap_pos_ids, + all_sig_pos_ids, + all_x_pad_mask, + all_cap_pad_mask, + all_sig_pad_mask, + all_x_pos_offsets, + all_x_noise_mask, + all_cap_noise_mask, + all_sig_noise_mask, ) - cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) - cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) - # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors - cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]] + def _prepare_sequence( + self, + feats: List[torch.Tensor], + pos_ids: List[torch.Tensor], + inner_pad_mask: List[torch.Tensor], + pad_token: torch.nn.Parameter, + noise_mask: Optional[List[List[int]]] = None, + device: torch.device = None, + ): + """Prepare sequence: apply pad token, RoPE embed, pad to batch, create attention mask.""" + item_seqlens = [len(f) for f in feats] + max_seqlen = max(item_seqlens) + bsz = len(feats) - cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(cap_item_seqlens): - cap_attn_mask[i, :seq_len] = 1 + # Pad token + feats_cat = torch.cat(feats, dim=0) + feats_cat[torch.cat(inner_pad_mask)] = pad_token + feats = list(feats_cat.split(item_seqlens, dim=0)) - if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.context_refiner: - cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis) - else: - for layer in self.context_refiner: - cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis) + # RoPE + freqs_cis = list(self.rope_embedder(torch.cat(pos_ids, dim=0)).split([len(p) for p in pos_ids], dim=0)) - # unified + # Pad to batch + feats = pad_sequence(feats, batch_first=True, padding_value=0.0) + freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]] + + # Attention mask + attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(item_seqlens): + attn_mask[i, :seq_len] = 1 + + # Noise mask + noise_mask_tensor = None + if noise_mask is not None: + noise_mask_tensor = pad_sequence( + [torch.tensor(m, dtype=torch.long, device=device) for m in noise_mask], + batch_first=True, + padding_value=0, + )[:, : feats.shape[1]] + + return feats, freqs_cis, attn_mask, item_seqlens, noise_mask_tensor + + def _build_unified_sequence( + self, + x: torch.Tensor, + x_freqs: torch.Tensor, + x_seqlens: List[int], + x_noise_mask: Optional[List[List[int]]], + cap: torch.Tensor, + cap_freqs: torch.Tensor, + cap_seqlens: List[int], + cap_noise_mask: Optional[List[List[int]]], + siglip: Optional[torch.Tensor], + siglip_freqs: Optional[torch.Tensor], + siglip_seqlens: Optional[List[int]], + siglip_noise_mask: Optional[List[List[int]]], + omni_mode: bool, + device: torch.device, + ): + """Build unified sequence: x, cap, and optionally siglip. + Basic mode order: [x, cap]; Omni mode order: [cap, x, siglip] + """ + bsz = len(x_seqlens) unified = [] - unified_freqs_cis = [] + unified_freqs = [] + unified_noise_mask = [] + for i in range(bsz): - x_len = x_item_seqlens[i] - cap_len = cap_item_seqlens[i] - unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]])) - unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]])) - unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] - assert unified_item_seqlens == [len(_) for _ in unified] - unified_max_item_seqlen = max(unified_item_seqlens) + x_len, cap_len = x_seqlens[i], cap_seqlens[i] - unified = pad_sequence(unified, batch_first=True, padding_value=0.0) - unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0) - unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(unified_item_seqlens): - unified_attn_mask[i, :seq_len] = 1 + if omni_mode: + # Omni: [cap, x, siglip] + if siglip is not None and siglip_seqlens is not None: + sig_len = siglip_seqlens[i] + unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len], siglip[i][:sig_len]])) + unified_freqs.append( + torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len], siglip_freqs[i][:sig_len]]) + ) + unified_noise_mask.append( + torch.tensor( + cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], dtype=torch.long, device=device + ) + ) + else: + unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len]])) + unified_freqs.append(torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len]])) + unified_noise_mask.append( + torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device) + ) + else: + # Basic: [x, cap] + unified.append(torch.cat([x[i][:x_len], cap[i][:cap_len]])) + unified_freqs.append(torch.cat([x_freqs[i][:x_len], cap_freqs[i][:cap_len]])) - if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer_idx, layer in enumerate(self.layers): - unified = self._gradient_checkpointing_func( - layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input - ) - if controlnet_block_samples is not None: - if layer_idx in controlnet_block_samples: - unified = unified + controlnet_block_samples[layer_idx] + # Compute unified seqlens + if omni_mode: + if siglip is not None and siglip_seqlens is not None: + unified_seqlens = [a + b + c for a, b, c in zip(cap_seqlens, x_seqlens, siglip_seqlens)] + else: + unified_seqlens = [a + b for a, b in zip(cap_seqlens, x_seqlens)] else: - for layer_idx, layer in enumerate(self.layers): - unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input) - if controlnet_block_samples is not None: - if layer_idx in controlnet_block_samples: - unified = unified + controlnet_block_samples[layer_idx] + unified_seqlens = [a + b for a, b in zip(x_seqlens, cap_seqlens)] - unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input) - unified = list(unified.unbind(dim=0)) - x = self.unpatchify(unified, x_size, patch_size, f_patch_size) + max_seqlen = max(unified_seqlens) - if not return_dict: - return (x,) + # Pad to batch + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0) - return Transformer2DModelOutput(sample=x) + # Attention mask + attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_seqlens): + attn_mask[i, :seq_len] = 1 + + # Noise mask + noise_mask_tensor = None + if omni_mode: + noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0)[ + :, : unified.shape[1] + ] + + return unified, unified_freqs, attn_mask, noise_mask_tensor + + def forward( + self, + x: Union[List[torch.Tensor], List[List[torch.Tensor]]], + t, + cap_feats: Union[List[torch.Tensor], List[List[torch.Tensor]]], + return_dict: bool = True, + controlnet_block_samples: Optional[Dict[int, torch.Tensor]] = None, + siglip_feats: Optional[List[List[torch.Tensor]]] = None, + image_noise_mask: Optional[List[List[int]]] = None, + patch_size: int = 2, + f_patch_size: int = 1, + ): + """ + Flow: patchify -> t_embed -> x_embed -> x_refine -> cap_embed -> cap_refine + -> [siglip_embed -> siglip_refine] -> build_unified -> main_layers -> final_layer -> unpatchify + """ + assert patch_size in self.all_patch_size and f_patch_size in self.all_f_patch_size + omni_mode = isinstance(x[0], list) + device = x[0][-1].device if omni_mode else x[0].device + + if omni_mode: + # Dual embeddings: noisy (t) and clean (t=1) + t_noisy = self.t_embedder(t * self.t_scale).type_as(x[0][-1]) + t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale).type_as(x[0][-1]) + adaln_input = None + else: + # Single embedding for all tokens + adaln_input = self.t_embedder(t * self.t_scale).type_as(x[0]) + t_noisy = t_clean = None + + # Patchify + if omni_mode: + ( + x, + cap_feats, + siglip_feats, + x_size, + x_pos_ids, + cap_pos_ids, + siglip_pos_ids, + x_pad_mask, + cap_pad_mask, + siglip_pad_mask, + x_pos_offsets, + x_noise_mask, + cap_noise_mask, + siglip_noise_mask, + ) = self.patchify_and_embed_omni(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask) + else: + ( + x, + cap_feats, + x_size, + x_pos_ids, + cap_pos_ids, + x_pad_mask, + cap_pad_mask, + ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + x_pos_offsets = x_noise_mask = cap_noise_mask = siglip_noise_mask = None + + # X embed & refine + x_seqlens = [len(xi) for xi in x] + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](torch.cat(x, dim=0)) # embed + x, x_freqs, x_mask, _, x_noise_tensor = self._prepare_sequence( + list(x.split(x_seqlens, dim=0)), x_pos_ids, x_pad_mask, self.x_pad_token, x_noise_mask, device + ) + + for layer in self.noise_refiner: + x = ( + self._gradient_checkpointing_func( + layer, x, x_mask, x_freqs, adaln_input, x_noise_tensor, t_noisy, t_clean + ) + if torch.is_grad_enabled() and self.gradient_checkpointing + else layer(x, x_mask, x_freqs, adaln_input, x_noise_tensor, t_noisy, t_clean) + ) + + # Cap embed & refine + cap_seqlens = [len(ci) for ci in cap_feats] + cap_feats = self.cap_embedder(torch.cat(cap_feats, dim=0)) # embed + cap_feats, cap_freqs, cap_mask, _, _ = self._prepare_sequence( + list(cap_feats.split(cap_seqlens, dim=0)), cap_pos_ids, cap_pad_mask, self.cap_pad_token, None, device + ) + + for layer in self.context_refiner: + cap_feats = ( + self._gradient_checkpointing_func(layer, cap_feats, cap_mask, cap_freqs) + if torch.is_grad_enabled() and self.gradient_checkpointing + else layer(cap_feats, cap_mask, cap_freqs) + ) + + # Siglip embed & refine + siglip_seqlens = siglip_freqs = None + if omni_mode and siglip_feats[0] is not None and self.siglip_embedder is not None: + siglip_seqlens = [len(si) for si in siglip_feats] + siglip_feats = self.siglip_embedder(torch.cat(siglip_feats, dim=0)) # embed + siglip_feats, siglip_freqs, siglip_mask, _, _ = self._prepare_sequence( + list(siglip_feats.split(siglip_seqlens, dim=0)), + siglip_pos_ids, + siglip_pad_mask, + self.siglip_pad_token, + None, + device, + ) + + for layer in self.siglip_refiner: + siglip_feats = ( + self._gradient_checkpointing_func(layer, siglip_feats, siglip_mask, siglip_freqs) + if torch.is_grad_enabled() and self.gradient_checkpointing + else layer(siglip_feats, siglip_mask, siglip_freqs) + ) + + # Unified sequence + unified, unified_freqs, unified_mask, unified_noise_tensor = self._build_unified_sequence( + x, + x_freqs, + x_seqlens, + x_noise_mask, + cap_feats, + cap_freqs, + cap_seqlens, + cap_noise_mask, + siglip_feats, + siglip_freqs, + siglip_seqlens, + siglip_noise_mask, + omni_mode, + device, + ) + + # Main transformer layers + for layer_idx, layer in enumerate(self.layers): + unified = ( + self._gradient_checkpointing_func( + layer, unified, unified_mask, unified_freqs, adaln_input, unified_noise_tensor, t_noisy, t_clean + ) + if torch.is_grad_enabled() and self.gradient_checkpointing + else layer(unified, unified_mask, unified_freqs, adaln_input, unified_noise_tensor, t_noisy, t_clean) + ) + if controlnet_block_samples is not None and layer_idx in controlnet_block_samples: + unified = unified + controlnet_block_samples[layer_idx] + + unified = ( + self.all_final_layer[f"{patch_size}-{f_patch_size}"]( + unified, noise_mask=unified_noise_tensor, c_noisy=t_noisy, c_clean=t_clean + ) + if omni_mode + else self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, c=adaln_input) + ) + + # Unpatchify + x = self.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size, x_pos_offsets) + + return (x,) if not return_dict else Transformer2DModelOutput(sample=x) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index e8faf868e7..f7615c1a44 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -411,6 +411,7 @@ else: "ZImagePipeline", "ZImageControlNetPipeline", "ZImageControlNetInpaintPipeline", + "ZImageOmniPipeline", ] _import_structure["skyreels_v2"] = [ "SkyReelsV2DiffusionForcingPipeline", @@ -856,6 +857,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ZImageControlNetInpaintPipeline, ZImageControlNetPipeline, ZImageImg2ImgPipeline, + ZImageOmniPipeline, ZImagePipeline, ) diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 4106a8fda7..c14910250b 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -120,7 +120,13 @@ from .stable_diffusion_xl import ( ) from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline -from .z_image import ZImageImg2ImgPipeline, ZImagePipeline +from .z_image import ( + ZImageControlNetInpaintPipeline, + ZImageControlNetPipeline, + ZImageImg2ImgPipeline, + ZImageOmniPipeline, + ZImagePipeline, +) AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( @@ -165,6 +171,9 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( ("qwenimage", QwenImagePipeline), ("qwenimage-controlnet", QwenImageControlNetPipeline), ("z-image", ZImagePipeline), + ("z-image-controlnet", ZImageControlNetPipeline), + ("z-image-controlnet-inpaint", ZImageControlNetInpaintPipeline), + ("z-image-omni", ZImageOmniPipeline), ("ovis", OvisImagePipeline), ] ) diff --git a/src/diffusers/pipelines/z_image/__init__.py b/src/diffusers/pipelines/z_image/__init__.py index 7b3cfbceea..78bd3bfacb 100644 --- a/src/diffusers/pipelines/z_image/__init__.py +++ b/src/diffusers/pipelines/z_image/__init__.py @@ -26,6 +26,7 @@ else: _import_structure["pipeline_z_image_controlnet"] = ["ZImageControlNetPipeline"] _import_structure["pipeline_z_image_controlnet_inpaint"] = ["ZImageControlNetInpaintPipeline"] _import_structure["pipeline_z_image_img2img"] = ["ZImageImg2ImgPipeline"] + _import_structure["pipeline_z_image_omni"] = ["ZImageOmniPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -41,7 +42,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .pipeline_z_image_controlnet import ZImageControlNetPipeline from .pipeline_z_image_controlnet_inpaint import ZImageControlNetInpaintPipeline from .pipeline_z_image_img2img import ZImageImg2ImgPipeline - + from .pipeline_z_image_omni import ZImageOmniPipeline else: import sys diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py b/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py new file mode 100644 index 0000000000..26848bea0a --- /dev/null +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py @@ -0,0 +1,742 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import PIL +import torch +from transformers import AutoTokenizer, PreTrainedModel, Siglip2ImageProcessorFast, Siglip2VisionModel + +from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import ZImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..flux2.image_processor import Flux2ImageProcessor +from .pipeline_output import ZImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ZImageOmniPipeline + + >>> pipe = ZImageOmniPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch. + >>> # (1) Use flash attention 2 + >>> # pipe.transformer.set_attention_backend("flash") + >>> # (2) Use flash attention 3 + >>> # pipe.transformer.set_attention_backend("_flash_3") + + >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。" + >>> image = pipe( + ... prompt, + ... height=1024, + ... width=1024, + ... num_inference_steps=9, + ... guidance_scale=0.0, + ... generator=torch.Generator("cuda").manual_seed(42), + ... ).images[0] + >>> image.save("zimage.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImageOmniPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin): + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageTransformer2DModel, + siglip: Siglip2VisionModel, + siglip_processor: Siglip2ImageProcessorFast, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + siglip=siglip, + siglip_processor=siglip_processor, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + # self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + num_condition_images: int = 0, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + num_condition_images=num_condition_images, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + num_condition_images=num_condition_images, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + max_sequence_length: int = 512, + num_condition_images: int = 0, + ) -> List[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + if num_condition_images == 0: + prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"] + elif num_condition_images > 0: + prompt_list = ["<|im_start|>user\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1) + prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|im_end|>"] + prompt[i] = prompt_list + + flattened_prompt = [] + prompt_list_lengths = [] + + for i in range(len(prompt)): + prompt_list_lengths.append(len(prompt[i])) + flattened_prompt.extend(prompt[i]) + + text_inputs = self.tokenizer( + flattened_prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + start_idx = 0 + for i in range(len(prompt_list_lengths)): + batch_embeddings = [] + end_idx = start_idx + prompt_list_lengths[i] + for j in range(start_idx, end_idx): + batch_embeddings.append(prompt_embeds[j][prompt_masks[j]]) + embeddings_list.append(batch_embeddings) + start_idx = end_idx + + return embeddings_list + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + def prepare_image_latents( + self, + images: List[torch.Tensor], + batch_size, + device, + dtype, + ): + image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + image_latent = ( + self.vae.encode(image.bfloat16()).latent_dist.mode()[0] - self.vae.config.shift_factor + ) * self.vae.config.scaling_factor + image_latent = image_latent.unsqueeze(1).to(dtype) + image_latents.append(image_latent) # (16, 128, 128) + + # image_latents = [image_latents] * batch_size + image_latents = [image_latents.copy() for _ in range(batch_size)] + + return image_latents + + def prepare_siglip_embeds( + self, + images: List[torch.Tensor], + batch_size, + device, + dtype, + ): + siglip_embeds = [] + for image in images: + siglip_inputs = self.siglip_processor(images=[image], return_tensors="pt").to(device) + shape = siglip_inputs.spatial_shapes[0] + hidden_state = self.siglip(**siglip_inputs).last_hidden_state + B, N, C = hidden_state.shape + hidden_state = hidden_state[:, : shape[0] * shape[1]] + hidden_state = hidden_state.view(shape[0], shape[1], C) + siglip_embeds.append(hidden_state.to(dtype)) + + # siglip_embeds = [siglip_embeds] * batch_size + siglip_embeds = [siglip_embeds.copy() for _ in range(batch_size)] + + return siglip_embeds + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + cfg_normalization (`bool`, *optional*, defaults to False): + Whether to apply configuration normalization. + cfg_truncation (`float`, *optional*, defaults to 1.0): + The truncation value for configuration. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain + tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + if image is not None and not isinstance(image, list): + image = [image] + num_condition_images = len(image) if image is not None else 0 + + device = self._execution_device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + num_condition_images=num_condition_images, + ) + + # 3. Process condition images. Copied from diffusers.pipelines.flux2.pipeline_flux2 + condition_images = [] + resized_images = [] + if image is not None: + for img in image: + self.image_processor.check_image_input(img) + for img in image: + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + if height is not None and width is not None: + img = self.image_processor._resize_to_target_area(img, height * width) + else: + img = self.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + resized_images.append(img) + + multiple_of = self.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + condition_images.append(img) + + if len(condition_images) > 0: + height = height or image_height + width = width or image_width + + else: + height = height or 1024 + width = width or 1024 + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.in_channels + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + condition_latents = self.prepare_image_latents( + images=condition_images, + batch_size=batch_size * num_images_per_prompt, + device=device, + dtype=torch.float32, + ) + condition_latents = [[lat.to(self.transformer.dtype) for lat in lats] for lats in condition_latents] + if self.do_classifier_free_guidance: + negative_condition_latents = [[lat.clone() for lat in batch] for batch in condition_latents] + + condition_siglip_embeds = self.prepare_siglip_embeds( + images=resized_images, + batch_size=batch_size * num_images_per_prompt, + device=device, + dtype=torch.float32, + ) + condition_siglip_embeds = [[se.to(self.transformer.dtype) for se in sels] for sels in condition_siglip_embeds] + if self.do_classifier_free_guidance: + negative_condition_siglip_embeds = [[se.clone() for se in batch] for batch in condition_siglip_embeds] + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + condition_siglip_embeds = [None if sels == [] else sels + [None] for sels in condition_siglip_embeds] + negative_condition_siglip_embeds = [ + None if sels == [] else sels + [None] for sels in negative_condition_siglip_embeds + ] + + actual_batch_size = batch_size * num_images_per_prompt + image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2) + + # 5. Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + condition_latents_model_input = condition_latents + negative_condition_latents + condition_siglip_embeds_model_input = condition_siglip_embeds + negative_condition_siglip_embeds + timestep_model_input = timestep.repeat(2) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + condition_latents_model_input = condition_latents + condition_siglip_embeds_model_input = condition_siglip_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + # Combine condition latents with target latent + current_batch_size = len(latent_model_input_list) + x_combined = [ + condition_latents_model_input[i] + [latent_model_input_list[i]] for i in range(current_batch_size) + ] + # Create noise mask: 0 for condition images (clean), 1 for target image (noisy) + image_noise_mask = [ + [0] * len(condition_latents_model_input[i]) + [1] for i in range(current_batch_size) + ] + + model_out_list = self.transformer( + x=x_combined, + t=timestep_model_input, + cap_feats=prompt_embeds_model_input, + siglip_feats=condition_siglip_embeds_model_input, + image_noise_mask=image_noise_mask, + return_dict=False, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 4e1eae211c..6c28e87581 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -3917,6 +3917,21 @@ class ZImageImg2ImgPipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) +class ZImageOmniPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class ZImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 1cdb8723b85f1b427031e390e0bd0bebfe92454e Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Mon, 29 Dec 2025 23:34:54 -0500 Subject: [PATCH 8/8] fix torchao quantizer for new torchao versions (#12901) * fix torchao quantizer for new torchao versions Summary: `torchao==0.16.0` (not yet released) has some bc-breaking changes, this PR fixes the diffusers repo with those changes. Specifics on the changes: 1. `UInt4Tensor` is removed: https://github.com/pytorch/ao/pull/3536 2. old float8 tensors v1 are removed: https://github.com/pytorch/ao/pull/3510 In this PR: 1. move the logger variable up (not sure why it was in the middle of the file before) to get better error messages 2. gate the old torchao objects by torchao version Test Plan: import diffusers objects with new versions of torchao works: ```bash > python -c "import torchao; print(torchao.__version__); from diffusers import StableDiffusionPipeline" 0.16.0.dev20251229+cu129 ``` Reviewers: Subscribers: Tasks: Tags: * Apply style fixes --------- Co-authored-by: github-actions[bot] --- .../quantizers/torchao/torchao_quantizer.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 2334c7af86..0405afdaae 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -36,6 +36,9 @@ from ...utils import ( from ..base import DiffusersQuantizer +logger = logging.get_logger(__name__) + + if TYPE_CHECKING: from ...models.modeling_utils import ModelMixin @@ -83,11 +86,19 @@ def _update_torch_safe_globals(): ] try: from torchao.dtypes import NF4Tensor - from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl - from torchao.dtypes.uintx.uint4_layout import UInt4Tensor from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor - safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor]) + safe_globals.extend([UintxTensor, UintxAQTTensorImpl, NF4Tensor]) + + # note: is_torchao_version(">=", "0.16.0") does not work correctly + # with torchao nightly, so using a ">" check which does work correctly + if is_torchao_version(">", "0.15.0"): + pass + else: + from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl + from torchao.dtypes.uintx.uint4_layout import UInt4Tensor + + safe_globals.extend([UInt4Tensor, Float8AQTTensorImpl]) except (ImportError, ModuleNotFoundError) as e: logger.warning( @@ -123,9 +134,6 @@ def fuzzy_match_size(config_name: str) -> Optional[str]: return None -logger = logging.get_logger(__name__) - - def _quantization_type(weight): from torchao.dtypes import AffineQuantizedTensor from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor