From b5309683cb6753e2111be9a8204f90a550c3fcb6 Mon Sep 17 00:00:00 2001 From: Miguel Martin Date: Thu, 18 Dec 2025 16:08:18 -0800 Subject: [PATCH 1/3] Cosmos Predict2.5 Base: inference pipeline, scheduler & chkpt conversion (#12852) * cosmos predict2.5 base: convert chkpt & pipeline - New scheduler: scheduling_flow_unipc_multistep.py - Changes to TransformerCosmos for text embeddings via crossattn_proj * scheduler cleanup * simplify inference pipeline * cleanup scheduler + tests * Basic tests for flow unipc * working b2b inference * Rename everything * Tests for pipeline present, but not working (predict2 also not working) * docstring update * wrapper pipelines + make style * remove unnecessary files * UniPCMultistep: support use_karras_sigmas=True and use_flow_sigmas=True * use UniPCMultistepScheduler + fix tests for pipeline * Remove FlowUniPCMultistepScheduler * UniPCMultistepScheduler for use_flow_sigmas=True & use_karras_sigmas=True * num_inference_steps=36 due to bug in scheduler used by predict2.5 * Address comments * make style + make fix-copies * fix tests + remove references to old pipelines * address comments * add revision in from_pretrained call * fix tests --- docs/source/en/api/pipelines/cosmos.md | 6 + scripts/convert_cosmos_to_diffusers.py | 135 ++- src/diffusers/__init__.py | 2 + .../models/transformers/transformer_cosmos.py | 13 + src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/cosmos/__init__.py | 6 + .../cosmos/pipeline_cosmos2_5_predict.py | 847 ++++++++++++++++++ .../schedulers/scheduling_unipc_multistep.py | 9 +- .../dummy_torch_and_transformers_objects.py | 15 + tests/pipelines/cosmos/cosmos_guardrail.py | 11 +- .../cosmos/test_cosmos2_5_predict.py | 337 +++++++ tests/schedulers/test_scheduler_unipc.py | 29 + 12 files changed, 1398 insertions(+), 14 deletions(-) create mode 100644 src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py create mode 100644 tests/pipelines/cosmos/test_cosmos2_5_predict.py diff --git a/docs/source/en/api/pipelines/cosmos.md b/docs/source/en/api/pipelines/cosmos.md index fb9453480e..60ecce6603 100644 --- a/docs/source/en/api/pipelines/cosmos.md +++ b/docs/source/en/api/pipelines/cosmos.md @@ -70,6 +70,12 @@ output.save("output.png") - all - __call__ +## Cosmos2_5_PredictBasePipeline + +[[autodoc]] Cosmos2_5_PredictBasePipeline + - all + - __call__ + ## CosmosPipelineOutput [[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput diff --git a/scripts/convert_cosmos_to_diffusers.py b/scripts/convert_cosmos_to_diffusers.py index 6f6563ad64..6e70f8cc05 100644 --- a/scripts/convert_cosmos_to_diffusers.py +++ b/scripts/convert_cosmos_to_diffusers.py @@ -1,11 +1,55 @@ +""" +# Cosmos 2 Predict + +Download checkpoint +```bash +hf download nvidia/Cosmos-Predict2-2B-Text2Image +``` + +convert checkpoint +```bash +transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2-2B-Text2Image/snapshots/acdb5fde992a73ef0355f287977d002cbfd127e0/model.pt + +python scripts/convert_cosmos_to_diffusers.py \ + --transformer_ckpt_path $transformer_ckpt_path \ + --transformer_type Cosmos-2.0-Diffusion-2B-Text2Image \ + --text_encoder_path google-t5/t5-11b \ + --tokenizer_path google-t5/t5-11b \ + --vae_type wan2.1 \ + --output_path converted/cosmos-p2-t2i-2b \ + --save_pipeline +``` + +# Cosmos 2.5 Predict + +Download checkpoint +```bash +hf download nvidia/Cosmos-Predict2.5-2B +``` + +Convert checkpoint +```bash 
+transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt
+
+python scripts/convert_cosmos_to_diffusers.py \
+    --transformer_type Cosmos-2.5-Predict-Base-2B \
+    --transformer_ckpt_path $transformer_ckpt_path \
+    --vae_type wan2.1 \
+    --output_path converted/cosmos-p2.5-base-2b \
+    --save_pipeline
+```
+
+"""
+
 import argparse
 import pathlib
+import sys
 from typing import Any, Dict
 
 import torch
 from accelerate import init_empty_weights
 from huggingface_hub import snapshot_download
-from transformers import T5EncoderModel, T5TokenizerFast
+from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration, T5EncoderModel, T5TokenizerFast
 
 from diffusers import (
     AutoencoderKLCosmos,
@@ -17,7 +61,9 @@ from diffusers import (
     CosmosVideoToWorldPipeline,
     EDMEulerScheduler,
     FlowMatchEulerDiscreteScheduler,
+    UniPCMultistepScheduler,
 )
+from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBasePipeline
 
 
 def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -233,6 +279,25 @@ TRANSFORMER_CONFIGS = {
         "concat_padding_mask": True,
         "extra_pos_embed_type": None,
     },
+    "Cosmos-2.5-Predict-Base-2B": {
+        "in_channels": 16 + 1,
+        "out_channels": 16,
+        "num_attention_heads": 16,
+        "attention_head_dim": 128,
+        "num_layers": 28,
+        "mlp_ratio": 4.0,
+        "text_embed_dim": 1024,
+        "adaln_lora_dim": 256,
+        "max_size": (128, 240, 240),
+        "patch_size": (1, 2, 2),
+        "rope_scale": (1.0, 3.0, 3.0),
+        "concat_padding_mask": True,
+        # NOTE: source config has pos_emb_learnable: 'True' - but params are missing
+        "extra_pos_embed_type": None,
+        "use_crossattn_projection": True,
+        "crossattn_proj_in_channels": 100352,
+        "encoder_hidden_states_channels": 1024,
+    },
 }
 
 VAE_KEYS_RENAME_DICT = {
@@ -334,6 +399,9 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
     elif "Cosmos-2.0" in transformer_type:
         TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
         TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
+    elif "Cosmos-2.5" in transformer_type:
+        TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
+        TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
     else:
         assert False
 
@@ -347,6 +415,7 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
         new_key = new_key.removeprefix(PREFIX_KEY)
         for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
             new_key = new_key.replace(replace_key, rename_key)
+        print(key, "->", new_key, flush=True)
         update_state_dict_(original_state_dict, key, new_key)
 
     for key in list(original_state_dict.keys()):
@@ -355,6 +424,21 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
             continue
         handler_fn_inplace(key, original_state_dict)
 
+    expected_keys = set(transformer.state_dict().keys())
+    mapped_keys = set(original_state_dict.keys())
+    missing_keys = expected_keys - mapped_keys
+    unexpected_keys = mapped_keys - expected_keys
+    if missing_keys:
+        print(f"ERROR: missing keys ({len(missing_keys)}) from state_dict:", flush=True, file=sys.stderr)
+        for k in missing_keys:
+            print(k)
+        sys.exit(1)
+    if unexpected_keys:
+        print(f"ERROR: unexpected keys ({len(unexpected_keys)}) from state_dict:", flush=True, file=sys.stderr)
+        for k in unexpected_keys:
+            print(k)
+        sys.exit(2)
+
     transformer.load_state_dict(original_state_dict, strict=True, assign=True)
    return
transformer @@ -444,6 +528,34 @@ def save_pipeline_cosmos_2_0(args, transformer, vae): pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") +def save_pipeline_cosmos2_5(args, transformer, vae): + text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B" + tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct" + + text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained( + text_encoder_path, torch_dtype="auto", device_map="cpu" + ) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + + scheduler = UniPCMultistepScheduler( + use_karras_sigmas=True, + use_flow_sigmas=True, + prediction_type="flow_prediction", + sigma_max=200.0, + sigma_min=0.01, + ) + + pipe = Cosmos2_5_PredictBasePipeline( + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + vae=vae, + scheduler=scheduler, + safety_checker=lambda *args, **kwargs: None, + ) + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + + def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys())) @@ -451,10 +563,10 @@ def get_args(): "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint" ) parser.add_argument( - "--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE" + "--vae_type", type=str, default="wan2.1", choices=["wan2.1", *list(VAE_CONFIGS.keys())], help="Type of VAE" ) - parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b") - parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b") + parser.add_argument("--text_encoder_path", type=str, default=None) + parser.add_argument("--tokenizer_path", type=str, default=None) parser.add_argument("--save_pipeline", action="store_true") parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.") @@ -477,8 +589,6 @@ if __name__ == "__main__": if args.save_pipeline: assert args.transformer_ckpt_path is not None assert args.vae_type is not None - assert args.text_encoder_path is not None - assert args.tokenizer_path is not None if args.transformer_ckpt_path is not None: weights_only = "Cosmos-1.0" in args.transformer_type @@ -490,17 +600,26 @@ if __name__ == "__main__": if args.vae_type is not None: if "Cosmos-1.0" in args.transformer_type: vae = convert_vae(args.vae_type) - else: + elif "Cosmos-2.0" in args.transformer_type or "Cosmos-2.5" in args.transformer_type: vae = AutoencoderKLWan.from_pretrained( "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32 ) + else: + raise AssertionError(f"{args.transformer_type} not supported") + if not args.save_pipeline: vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") if args.save_pipeline: if "Cosmos-1.0" in args.transformer_type: + assert args.text_encoder_path is not None + assert args.tokenizer_path is not None save_pipeline_cosmos_1_0(args, transformer, vae) elif "Cosmos-2.0" in args.transformer_type: + assert args.text_encoder_path is not None + assert args.tokenizer_path is not None save_pipeline_cosmos_2_0(args, transformer, vae) + elif "Cosmos-2.5" in args.transformer_type: + save_pipeline_cosmos2_5(args, transformer, vae) else: - assert False + raise 
AssertionError(f"{args.transformer_type} not supported") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 03ecaf6bc1..6aac3feffd 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -463,6 +463,7 @@ else: "CogView4ControlPipeline", "CogView4Pipeline", "ConsisIDPipeline", + "Cosmos2_5_PredictBasePipeline", "Cosmos2TextToImagePipeline", "Cosmos2VideoToWorldPipeline", "CosmosTextToWorldPipeline", @@ -1175,6 +1176,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: CogView4ControlPipeline, CogView4Pipeline, ConsisIDPipeline, + Cosmos2_5_PredictBasePipeline, Cosmos2TextToImagePipeline, Cosmos2VideoToWorldPipeline, CosmosTextToWorldPipeline, diff --git a/src/diffusers/models/transformers/transformer_cosmos.py b/src/diffusers/models/transformers/transformer_cosmos.py index 373b470ae3..2b0c266707 100644 --- a/src/diffusers/models/transformers/transformer_cosmos.py +++ b/src/diffusers/models/transformers/transformer_cosmos.py @@ -439,6 +439,9 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0), concat_padding_mask: bool = True, extra_pos_embed_type: Optional[str] = "learnable", + use_crossattn_projection: bool = False, + crossattn_proj_in_channels: int = 1024, + encoder_hidden_states_channels: int = 1024, ) -> None: super().__init__() hidden_size = num_attention_heads * attention_head_dim @@ -485,6 +488,12 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=False ) + if self.config.use_crossattn_projection: + self.crossattn_proj = nn.Sequential( + nn.Linear(crossattn_proj_in_channels, encoder_hidden_states_channels, bias=True), + nn.GELU(), + ) + self.gradient_checkpointing = False def forward( @@ -524,6 +533,7 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): post_patch_num_frames = num_frames // p_t post_patch_height = height // p_h post_patch_width = width // p_w + hidden_states = self.patch_embed(hidden_states) hidden_states = hidden_states.flatten(1, 3) # [B, T, H, W, C] -> [B, THW, C] @@ -546,6 +556,9 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): else: assert False + if self.config.use_crossattn_projection: + encoder_hidden_states = self.crossattn_proj(encoder_hidden_states) + # 5. 
Transformer blocks for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 04ec6b5cd8..e8faf868e7 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -165,6 +165,7 @@ else: _import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"] _import_structure["consisid"] = ["ConsisIDPipeline"] _import_structure["cosmos"] = [ + "Cosmos2_5_PredictBasePipeline", "Cosmos2TextToImagePipeline", "CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline", @@ -622,6 +623,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: StableDiffusionXLControlNetXSPipeline, ) from .cosmos import ( + Cosmos2_5_PredictBasePipeline, Cosmos2TextToImagePipeline, Cosmos2VideoToWorldPipeline, CosmosTextToWorldPipeline, diff --git a/src/diffusers/pipelines/cosmos/__init__.py b/src/diffusers/pipelines/cosmos/__init__.py index 2833c89abd..944f165531 100644 --- a/src/diffusers/pipelines/cosmos/__init__.py +++ b/src/diffusers/pipelines/cosmos/__init__.py @@ -22,6 +22,9 @@ except OptionalDependencyNotAvailable: _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: + _import_structure["pipeline_cosmos2_5_predict"] = [ + "Cosmos2_5_PredictBasePipeline", + ] _import_structure["pipeline_cosmos2_text2image"] = ["Cosmos2TextToImagePipeline"] _import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"] _import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"] @@ -35,6 +38,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: + from .pipeline_cosmos2_5_predict import ( + Cosmos2_5_PredictBasePipeline, + ) from .pipeline_cosmos2_text2image import Cosmos2TextToImagePipeline from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py new file mode 100644 index 0000000000..6564b59373 --- /dev/null +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py @@ -0,0 +1,847 @@ +# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
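+#
+# Inference pipeline for the Cosmos Predict2.5 base model. A single `__call__`
+# covers Text2World (no image/video), Image2World (one conditioning image) and
+# Video2World (a conditioning clip); conditioning frames are injected through a
+# latent-space mask rather than through separate per-mode pipelines.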
+
+from typing import Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+import torchvision
+import torchvision.transforms
+import torchvision.transforms.functional
+from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...models import AutoencoderKLWan, CosmosTransformer3DModel
+from ...schedulers import UniPCMultistepScheduler
+from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import CosmosPipelineOutput
+
+
+if is_cosmos_guardrail_available():
+    from cosmos_guardrail import CosmosSafetyChecker
+else:
+
+    class CosmosSafetyChecker:
+        def __init__(self, *args, **kwargs):
+            raise ImportError(
+                "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
+            )
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```python
+        >>> import torch
+        >>> from diffusers import Cosmos2_5_PredictBasePipeline
+        >>> from diffusers.utils import export_to_video, load_image, load_video
+
+        >>> model_id = "nvidia/Cosmos-Predict2.5-2B"
+        >>> pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
+        ...     model_id, revision="diffusers/base/pre-trained", torch_dtype=torch.bfloat16
+        ... )
+        >>> pipe = pipe.to("cuda")
+
+        >>> # Common negative prompt reused across modes.
+        >>> negative_prompt = (
+        ...     "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
+        ...     "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
+        ...     "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky "
+        ...     "movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
+        ...     "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
+        ...     "Overall, the video is of poor quality."
+        ... )
+
+        >>> # Text2World: generate a 93-frame world video from text only.
+        >>> prompt = (
+        ...     "As the red light shifts to green, the red bus at the intersection begins to move forward, its headlights "
+        ...     "cutting through the falling snow. The snowy tire tracks deepen as the vehicle inches ahead, casting fresh "
+        ...     "lines onto the slushy road. Around it, streetlights glow warmer, illuminating the drifting flakes and wet "
+        ...     
"reflections on the asphalt. Other cars behind start to edge forward, their beams joining the scene. " + ... "The stillness of the urban street transitions into motion as the quiet snowfall is punctuated by the slow " + ... "advance of traffic through the frosty city corridor." + ... ) + >>> video = pipe( + ... image=None, + ... video=None, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_frames=93, + ... generator=torch.Generator().manual_seed(1), + ... ).frames[0] + >>> export_to_video(video, "text2world.mp4", fps=16) + + >>> # Image2World: condition on a single image and generate a 93-frame world video. + >>> prompt = ( + ... "A high-definition video captures the precision of robotic welding in an industrial setting. " + ... "The first frame showcases a robotic arm, equipped with a welding torch, positioned over a large metal structure. " + ... "The welding process is in full swing, with bright sparks and intense light illuminating the scene, creating a vivid " + ... "display of blue and white hues. A significant amount of smoke billows around the welding area, partially obscuring " + ... "the view but emphasizing the heat and activity. The background reveals parts of the workshop environment, including a " + ... "ventilation system and various pieces of machinery, indicating a busy and functional industrial workspace. As the video " + ... "progresses, the robotic arm maintains its steady position, continuing the welding process and moving to its left. " + ... "The welding torch consistently emits sparks and light, and the smoke continues to rise, diffusing slightly as it moves upward. " + ... "The metal surface beneath the torch shows ongoing signs of heating and melting. The scene retains its industrial ambiance, with " + ... "the welding sparks and smoke dominating the visual field, underscoring the ongoing nature of the welding operation." + ... ) + >>> image = load_image( + ... "https://media.githubusercontent.com/media/nvidia-cosmos/cosmos-predict2.5/refs/heads/main/assets/base/robot_welding.jpg" + ... ) + >>> video = pipe( + ... image=image, + ... video=None, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_frames=93, + ... generator=torch.Generator().manual_seed(1), + ... ).frames[0] + >>> # export_to_video(video, "image2world.mp4", fps=16) + + >>> # Video2World: condition on an input clip and predict a 93-frame world video. + >>> prompt = ( + ... "The video opens with an aerial view of a large-scale sand mining construction operation, showcasing extensive piles " + ... "of brown sand meticulously arranged in parallel rows. A central water channel, fed by a water pipe, flows through the " + ... "middle of these sand heaps, creating ripples and movement as it cascades down. The surrounding area features dense green " + ... "vegetation on the left, contrasting with the sandy terrain, while a body of water is visible in the background on the right. " + ... "As the video progresses, a piece of heavy machinery, likely a bulldozer, enters the frame from the right, moving slowly along " + ... "the edge of the sand piles. This machinery's presence indicates ongoing construction work in the operation. The final frame " + ... "captures the same scene, with the water continuing its flow and the bulldozer still in motion, maintaining the dynamic yet " + ... "steady pace of the construction activity." + ... ) + >>> input_video = load_video( + ... 
"https://github.com/nvidia-cosmos/cosmos-predict2.5/raw/refs/heads/main/assets/base/sand_mining.mp4" + ... ) + >>> video = pipe( + ... image=None, + ... video=input_video, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_frames=93, + ... generator=torch.Generator().manual_seed(1), + ... ).frames[0] + >>> export_to_video(video, "video2world.mp4", fps=16) + + >>> # To produce an image instead of a world (video) clip, set num_frames=1 and + >>> # save the first frame: pipe(..., num_frames=1).frames[0][0]. + ``` +""" + + +class Cosmos2_5_PredictBasePipeline(DiffusionPipeline): + r""" + Pipeline for [Cosmos Predict2.5](https://github.com/nvidia-cosmos/cosmos-predict2.5) base model. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`Qwen2_5_VLForConditionalGeneration`]): + Frozen text-encoder. Cosmos Predict2.5 uses the [Qwen2.5 + VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) encoder. + tokenizer (`AutoTokenizer`): + Tokenizer associated with the Qwen2.5 VL encoder. + transformer ([`CosmosTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + # We mark safety_checker as optional here to get around some test failures, but it is not really optional + _optional_components = ["safety_checker"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: AutoTokenizer, + transformer: CosmosTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: UniPCMultistepScheduler, + safety_checker: CosmosSafetyChecker = None, + ): + super().__init__() + + if safety_checker is None: + safety_checker = CosmosSafetyChecker() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).float() + if getattr(self.vae.config, "latents_mean", None) is not None + else None + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).float() + if getattr(self.vae.config, "latents_std", None) is not None + else None + ) + self.latents_mean = latents_mean + self.latents_std = latents_std + + if self.latents_mean is None or self.latents_std is None: + raise ValueError("VAE configuration must define both `latents_mean` and `latents_std`.") + + def _get_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = 
None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + + input_ids_batch = [] + + for sample_idx in range(len(prompt)): + conversations = [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "You are a helpful assistant who will provide prompts to an image generator.", + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt[sample_idx], + } + ], + }, + ] + input_ids = self.tokenizer.apply_chat_template( + conversations, + tokenize=True, + add_generation_prompt=False, + add_vision_id=False, + max_length=max_sequence_length, + truncation=True, + padding="max_length", + ) + input_ids = torch.LongTensor(input_ids) + input_ids_batch.append(input_ids) + + input_ids_batch = torch.stack(input_ids_batch, dim=0) + + outputs = self.text_encoder( + input_ids_batch.to(device), + output_hidden_states=True, + ) + hidden_states = outputs.hidden_states + + normalized_hidden_states = [] + for layer_idx in range(1, len(hidden_states)): + normalized_state = (hidden_states[layer_idx] - hidden_states[layer_idx].mean(dim=-1, keepdim=True)) / ( + hidden_states[layer_idx].std(dim=-1, keepdim=True) + 1e-8 + ) + normalized_hidden_states.append(normalized_state) + + prompt_embeds = torch.cat(normalized_hidden_states, dim=-1) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds + + # Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
+ device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_prompt_embeds( + prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_prompt_embeds( + prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + # Modified from diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2VideoToWorldPipeline.prepare_latents and + # diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2TextToImagePipeline.prepare_latents + def prepare_latents( + self, + video: Optional[torch.Tensor], + batch_size: int, + num_channels_latents: int = 16, + height: int = 704, + width: int = 1280, + num_frames_in: int = 93, + num_frames_out: int = 93, + do_classifier_free_guidance: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + B = batch_size + C = num_channels_latents + T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1 + H = height // self.vae_scale_factor_spatial + W = width // self.vae_scale_factor_spatial + shape = (B, C, T, H, W) + + if num_frames_in == 0: + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device) + cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device) + + cond_latents = torch.zeros_like(latents) + + return ( + latents, + cond_latents, + cond_mask, + cond_indicator, + ) + else: + if video is None: + raise ValueError("`video` must be provided when `num_frames_in` is greater than 0.") + needs_preprocessing = not (isinstance(video, torch.Tensor) and video.ndim == 5 and video.shape[1] == 3) + if needs_preprocessing: + video = self.video_processor.preprocess_video(video, height, width) + video = video.to(device=device, dtype=self.vae.dtype) + if isinstance(generator, list): + cond_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i]) + for i in range(batch_size) + ] + else: + cond_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + cond_latents = torch.cat(cond_latents, dim=0).to(dtype) + + latents_mean = self.latents_mean.to(device=device, dtype=dtype) + latents_std = self.latents_std.to(device=device, dtype=dtype) + cond_latents = (cond_latents - latents_mean) / latents_std + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + padding_shape = (B, 1, T, H, W) + ones_padding = latents.new_ones(padding_shape) + zeros_padding = latents.new_zeros(padding_shape) + + num_cond_latent_frames = (num_frames_in - 1) // self.vae_scale_factor_temporal + 1 + cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1) + cond_indicator[:, :, 0:num_cond_latent_frames] = 1.0 + cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding + + return ( + latents, + cond_latents, + cond_mask, + cond_indicator, + ) + + # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1.0
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        image: PipelineImageInput | None = None,
+        video: List[PipelineImageInput] | None = None,
+        prompt: Union[str, List[str]] | None = None,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        height: int = 704,
+        width: int = 1280,
+        num_frames: int = 93,
+        num_inference_steps: int = 36,
+        guidance_scale: float = 7.0,
+        num_videos_per_prompt: Optional[int] = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 512,
+        conditional_frame_timestep: float = 0.1,
+    ):
+        r"""
+        The call function to the pipeline for generation. Supports three modes:
+
+        - **Text2World**: `image=None`, `video=None`, `prompt` provided. Generates a world clip.
+        - **Image2World**: `image` provided, `video=None`, `prompt` provided. Conditions on a single frame.
+        - **Video2World**: `video` provided, `image=None`, `prompt` provided. Conditions on an input clip.
+
+        Set `num_frames=93` (default) to produce a world video, or `num_frames=1` to produce a single image frame
+        (i.e., the corresponding "*2Image" mode).
+
+        Outputs follow `output_type` (e.g., `"pil"` returns a list of `num_frames` PIL images per prompt).
+
+        Args:
+            image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*):
+                Optional single image for Image2World conditioning. Must be `None` when `video` is provided.
+            video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*):
+                Optional input video for Video2World conditioning. Must be `None` when `image` is provided.
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied.
+            height (`int`, defaults to `704`):
+                The height in pixels of the generated image.
+            width (`int`, defaults to `1280`):
+                The width in pixels of the generated image.
+            num_frames (`int`, defaults to `93`):
+                Number of output frames. Use `93` for world (video) generation; set to `1` to return a single frame.
+            num_inference_steps (`int`, defaults to `36`):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, defaults to `7.0`):
+                Guidance scale as defined in [Classifier-Free Diffusion
+                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+                `guidance_scale > 1`.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                The number of videos to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+                generation deterministic.
+            latents (`torch.Tensor`, *optional*):
+                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor is generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
+            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+                each denoising step during inference with the following arguments: `callback_on_step_end(self:
+                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, defaults to `512`):
+                The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
+                the prompt is shorter than this length, it will be padded.
+            conditional_frame_timestep (`float`, defaults to `0.1`):
+                Timestep assigned to the conditioning frames when calling the transformer, so that conditioning
+                latents are treated as nearly clean samples rather than pure noise.
+
+        Examples:
+
+        Returns:
+            [`~CosmosPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where
+                the first element is a list with the generated frames.
+        """
+        if self.safety_checker is None:
+            raise ValueError(
+                f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
+                "[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
+                "Please ensure that you are compliant with the license agreement."
+            )
+
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+        # Check inputs.
Raise error if not correct + self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs) + + self._guidance_scale = guidance_scale + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + if self.safety_checker is not None: + self.safety_checker.to(device) + if prompt is not None: + prompt_list = [prompt] if isinstance(prompt, str) else prompt + for p in prompt_list: + if not self.safety_checker.check_text_safety(p): + raise ValueError( + f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the " + f"prompt abides by the NVIDIA Open Model License Agreement." + ) + + # Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + vae_dtype = self.vae.dtype + transformer_dtype = self.transformer.dtype + + num_frames_in = None + if image is not None: + if batch_size != 1: + raise ValueError(f"batch_size must be 1 for image input (given {batch_size})") + + image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0) + video = torch.cat([image, torch.zeros_like(image).repeat(num_frames - 1, 1, 1, 1)], dim=0) + video = video.unsqueeze(0) + num_frames_in = 1 + elif video is None: + video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8) + num_frames_in = 0 + else: + num_frames_in = len(video) + + if batch_size != 1: + raise ValueError(f"batch_size must be 1 for video input (given {batch_size})") + + assert video is not None + video = self.video_processor.preprocess_video(video, height, width) + + # pad with last frame (for video2world) + num_frames_out = num_frames + if video.shape[2] < num_frames_out: + n_pad_frames = num_frames_out - num_frames_in + last_frame = video[0, :, -1:, :, :] # [C, T==1, H, W] + pad_frames = last_frame.repeat(1, 1, n_pad_frames, 1, 1) # [B, C, T, H, W] + video = torch.cat((video, pad_frames), dim=2) + + assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})" + + video = video.to(device=device, dtype=vae_dtype) + + num_channels_latents = self.transformer.config.in_channels - 1 + latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents( + video=video, + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames_in=num_frames_in, + num_frames_out=num_frames, + do_classifier_free_guidance=self.do_classifier_free_guidance, + dtype=torch.float32, + device=device, + generator=generator, + latents=latents, + ) + cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep + cond_mask = cond_mask.to(transformer_dtype) + + padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype) + + # Denoising loop + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + 
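+        # gt_velocity is the exact flow-matching velocity (initial noise minus the
+        # encoded conditioning latents), masked to the conditioning frames. Inside the
+        # denoising loop the predicted velocity is overwritten with it on those frames,
+        # so the conditioning latents are reproduced exactly at every step.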
+ gt_velocity = (latents - cond_latent) * cond_mask + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t.cpu().item() + + # NOTE: assumes sigma(t) \in [0, 1] + sigma_t = ( + torch.tensor(self.scheduler.sigmas[i].item()) + .unsqueeze(0) + .to(device=device, dtype=transformer_dtype) + ) + + in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents + in_latents = in_latents.to(transformer_dtype) + in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t + noise_pred = self.transformer( + hidden_states=in_latents, + condition_mask=cond_mask, + timestep=in_timestep, + encoder_hidden_states=prompt_embeds, + padding_mask=padding_mask, + return_dict=False, + )[0] + # NOTE: replace velocity (noise_pred) with gt_velocity for conditioning inputs only + noise_pred = gt_velocity + noise_pred * (1 - cond_mask) + + if self.do_classifier_free_guidance: + noise_pred_neg = self.transformer( + hidden_states=in_latents, + condition_mask=cond_mask, + timestep=in_timestep, + encoder_hidden_states=negative_prompt_embeds, + padding_mask=padding_mask, + return_dict=False, + )[0] + # NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only + noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask) + noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg) + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents_mean = self.latents_mean.to(latents.device, latents.dtype) + latents_std = self.latents_std.to(latents.device, latents.dtype) + latents = latents * latents_std + latents_mean + video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0] + video = self._match_num_frames(video, num_frames) + + assert self.safety_checker is not None + self.safety_checker.to(device) + video = self.video_processor.postprocess_video(video, output_type="np") + video = (video * 255).astype(np.uint8) + video_batch = [] + for vid in video: + vid = self.safety_checker.check_video_safety(vid) + video_batch.append(vid) + video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1 + video = torch.from_numpy(video).permute(0, 4, 1, 2, 3) + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CosmosPipelineOutput(frames=video) + + def _match_num_frames(self, video: torch.Tensor, target_num_frames: int) -> torch.Tensor: + if target_num_frames <= 0 or video.shape[2] == target_num_frames: + return video + + frames_per_latent = max(self.vae_scale_factor_temporal, 1) + video = torch.repeat_interleave(video, repeats=frames_per_latent, dim=2) 
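+        # The upsampled clip may over- or undershoot the target, so pad with the last
+        # frame or trim below to land exactly on target_num_frames.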
+ + current_frames = video.shape[2] + if current_frames < target_num_frames: + pad = video[:, :, -1:, :, :].repeat(1, 1, target_num_frames - current_frames, 1, 1) + video = torch.cat([video, pad], dim=2) + elif current_frames > target_num_frames: + video = video[:, :, :target_num_frames] + + return video diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 689c6a0635..5ea56b300b 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -217,6 +217,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): rescale_betas_zero_snr: bool = False, use_dynamic_shifting: bool = False, time_shift_type: Literal["exponential"] = "exponential", + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, ) -> None: if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") @@ -350,7 +352,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): log_sigmas = np.log(sigmas) sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + if self.config.use_flow_sigmas: + sigmas = sigmas / (sigmas + 1) + timesteps = (sigmas * self.config.num_train_timesteps).copy() + else: + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + if self.config.final_sigmas_type == "sigma_min": sigma_last = sigmas[-1] elif self.config.final_sigmas_type == "zero": diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 74a4146bfd..4e1eae211c 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -767,6 +767,21 @@ class ConsisIDPipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) +class Cosmos2_5_PredictBasePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class Cosmos2TextToImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/cosmos/cosmos_guardrail.py b/tests/pipelines/cosmos/cosmos_guardrail.py index 4de14fbaaf..c9ef597fdb 100644 --- a/tests/pipelines/cosmos/cosmos_guardrail.py +++ b/tests/pipelines/cosmos/cosmos_guardrail.py @@ -27,7 +27,7 @@ class DummyCosmosSafetyChecker(ModelMixin, ConfigMixin): def __init__(self) -> None: super().__init__() - self._dtype = torch.float32 + self.register_buffer("_device_tracker", torch.zeros(1, dtype=torch.float32), persistent=False) def check_text_safety(self, prompt: str) -> bool: return True @@ -35,13 +35,14 @@ class DummyCosmosSafetyChecker(ModelMixin, ConfigMixin): def check_video_safety(self, frames: np.ndarray) -> np.ndarray: return frames - def to(self, device: Union[str, torch.device] = None, dtype: torch.dtype = None) -> None: - self._dtype = dtype + def to(self, device: Union[str, torch.device] = None, dtype: torch.dtype = None): + module = 
super().to(device=device, dtype=dtype) + return module @property def device(self) -> torch.device: - return None + return self._device_tracker.device @property def dtype(self) -> torch.dtype: - return self._dtype + return self._device_tracker.dtype diff --git a/tests/pipelines/cosmos/test_cosmos2_5_predict.py b/tests/pipelines/cosmos/test_cosmos2_5_predict.py new file mode 100644 index 0000000000..54d4edb485 --- /dev/null +++ b/tests/pipelines/cosmos/test_cosmos2_5_predict.py @@ -0,0 +1,337 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import json +import os +import tempfile +import unittest + +import numpy as np +import torch +from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer + +from diffusers import ( + AutoencoderKLWan, + Cosmos2_5_PredictBasePipeline, + CosmosTransformer3DModel, + UniPCMultistepScheduler, +) + +from ...testing_utils import enable_full_determinism, torch_device +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np +from .cosmos_guardrail import DummyCosmosSafetyChecker + + +enable_full_determinism() + + +class Cosmos2_5_PredictBaseWrapper(Cosmos2_5_PredictBasePipeline): + @staticmethod + def from_pretrained(*args, **kwargs): + if "safety_checker" not in kwargs or kwargs["safety_checker"] is None: + safety_checker = DummyCosmosSafetyChecker() + device_map = kwargs.get("device_map", "cpu") + torch_dtype = kwargs.get("torch_dtype") + if device_map is not None or torch_dtype is not None: + safety_checker = safety_checker.to(device_map, dtype=torch_dtype) + kwargs["safety_checker"] = safety_checker + return Cosmos2_5_PredictBasePipeline.from_pretrained(*args, **kwargs) + + +class Cosmos2_5_PredictPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = Cosmos2_5_PredictBaseWrapper + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + supports_dduf = False + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = CosmosTransformer3DModel( + in_channels=16 + 1, + out_channels=16, + num_attention_heads=2, + attention_head_dim=16, + num_layers=2, + mlp_ratio=2, + text_embed_dim=32, + adaln_lora_dim=4, + max_size=(4, 32, 32), + patch_size=(1, 2, 2), + rope_scale=(2.0, 1.0, 1.0), + concat_padding_mask=True, + extra_pos_embed_type="learnable", + ) + + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, 
True, True], + ) + + torch.manual_seed(0) + scheduler = UniPCMultistepScheduler() + + torch.manual_seed(0) + config = Qwen2_5_VLConfig( + text_config={ + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "rope_scaling": { + "mrope_section": [1, 1, 2], + "rope_type": "default", + "type": "default", + }, + "rope_theta": 1000000.0, + }, + vision_config={ + "depth": 2, + "hidden_size": 16, + "intermediate_size": 16, + "num_heads": 2, + "out_hidden_size": 16, + }, + hidden_size=16, + vocab_size=152064, + vision_end_token_id=151653, + vision_start_token_id=151652, + vision_token_id=151654, + ) + text_encoder = Qwen2_5_VLForConditionalGeneration(config) + tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": DummyCosmosSafetyChecker(), + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "dance monkey", + "negative_prompt": "bad quality", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 3.0, + "height": 32, + "width": 32, + "num_frames": 3, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_components_function(self): + init_components = self.get_dummy_components() + init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))} + pipe = self.pipeline_class(**init_components) + self.assertTrue(hasattr(pipe, "components")) + self.assertTrue(set(pipe.components.keys()) == set(init_components.keys())) + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + self.assertEqual(generated_video.shape, (3, 3, 32, 32)) + self.assertTrue(torch.isfinite(generated_video).all()) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + for tensor_name in callback_kwargs.keys(): + assert tensor_name in pipe._callback_tensor_inputs + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + for tensor_name in callback_kwargs.keys(): + assert tensor_name in pipe._callback_tensor_inputs + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + 
inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + _ = pipe(**inputs)[0] + + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + _ = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=1e-2) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not getattr(self, "test_attention_slicing", True): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_save_load_optional_components(self, expected_max_difference=1e-4): + self.pipeline_class._optional_components.remove("safety_checker") + super().test_save_load_optional_components(expected_max_difference=expected_max_difference) + self.pipeline_class._optional_components.append("safety_checker") + + def test_serialization_with_variants(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + model_components = [ + component_name + for component_name, component in pipe.components.items() + if isinstance(component, torch.nn.Module) + ] + model_components.remove("safety_checker") + variant = "fp16" + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False) + + with open(f"{tmpdir}/model_index.json", "r") as f: + config = json.load(f) + + for subfolder in os.listdir(tmpdir): + if not os.path.isfile(subfolder) and subfolder in model_components: + folder_path = os.path.join(tmpdir, subfolder) + is_folder = os.path.isdir(folder_path) and subfolder in config + assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)) + + def test_torch_dtype_dict(self): + components = self.get_dummy_components() + if not components: + self.skipTest("No dummy components defined.") + + pipe = self.pipeline_class(**components) + + specified_key = next(iter(components.keys())) + + with 
tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
+            pipe.save_pretrained(tmpdirname, safe_serialization=False)
+            torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
+            loaded_pipe = self.pipeline_class.from_pretrained(
+                tmpdirname, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=torch_dtype_dict
+            )
+
+        for name, component in loaded_pipe.components.items():
+            if name == "safety_checker":
+                continue
+            if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
+                expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
+                self.assertEqual(
+                    component.dtype,
+                    expected_dtype,
+                    f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
+                )
+
+    @unittest.skip(
+        "The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in "
+        "a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is "
+        "too large and slow to run on CI."
+    )
+    def test_encode_prompt_works_in_isolation(self):
+        pass
diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py
index 197c831cb0..ac7e1d3f88 100644
--- a/tests/schedulers/test_scheduler_unipc.py
+++ b/tests/schedulers/test_scheduler_unipc.py
@@ -399,3 +399,32 @@ class UniPCMultistepScheduler1DTest(UniPCMultistepSchedulerTest):

     def test_exponential_sigmas(self):
         self.check_over_configs(use_exponential_sigmas=True)
+
+    def test_flow_and_karras_sigmas(self):
+        self.check_over_configs(use_flow_sigmas=True, use_karras_sigmas=True)
+
+    def test_flow_and_karras_sigmas_values(self):
+        num_train_timesteps = 1000
+        num_inference_steps = 5
+        scheduler = UniPCMultistepScheduler(
+            sigma_min=0.01,
+            sigma_max=200.0,
+            use_flow_sigmas=True,
+            use_karras_sigmas=True,
+            num_train_timesteps=num_train_timesteps,
+        )
+        scheduler.set_timesteps(num_inference_steps=num_inference_steps)
+
+        expected_sigmas = [
+            0.9950248599052429,
+            0.9787454605102539,
+            0.8774884343147278,
+            0.3604971766471863,
+            0.009900986216962337,
+            0.0,  # 0 appended as default
+        ]
+        expected_sigmas = torch.tensor(expected_sigmas)
+        expected_timesteps = (expected_sigmas * num_train_timesteps).to(torch.int64)
+        expected_timesteps = expected_timesteps[0:-1]
+        self.assertTrue(torch.allclose(scheduler.sigmas, expected_sigmas))
+        self.assertTrue(torch.all(expected_timesteps == scheduler.timesteps))

From f7753b1bc8b4b3b97dc7f71d51ccb3a281b17b48 Mon Sep 17 00:00:00 2001
From: YiYi Xu
Date: Thu, 18 Dec 2025 19:25:20 -1000
Subject: [PATCH 2/3] more update in modular (#12560)

* move node registry to mellon
* up
* fix
* modular pipeline update: filter out None for input_names, fix default blocks for pipe.init(), and allow users to pass additional kwargs_type in a dict
* qwen modular refactor, unpack before decode
* update mellon node config, adding "*" to required_inputs and required_model_inputs
* ModularPipeline.from_pretrained: error out if no config found
* add a component_names property to modular blocks to be consistent!
* flux image_encoder -> vae_encoder
* controlnet_bundle
* refactor MellonNodeConfig and MellonPipelineConfig
* refactor & simplify mellon utils
* vae_image_encoder -> vae_encoder
* mellon config save: keep key order
* style + copies
* add kwargs input for zimage
---
 .../modular_pipelines/flux/modular_blocks.py  |    4 +-
 .../modular_pipelines/mellon_node_utils.py    | 1141 +++++++----------
 .../modular_pipelines/modular_pipeline.py     |   25 +-
 .../modular_pipelines/qwenimage/decoders.py   |   54 +-
 .../qwenimage/modular_blocks.py               |   29 +-
 .../modular_pipelines/qwenimage/node_utils.py |   95 --
 .../stable_diffusion_xl/node_utils.py         |   99 --
 .../modular_pipelines/z_image/denoise.py      |    4 +
 .../z_image/modular_blocks.py                 |    8 +-
 9 files changed, 585 insertions(+), 874 deletions(-)
 delete mode 100644 src/diffusers/modular_pipelines/qwenimage/node_utils.py
 delete mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/node_utils.py

diff --git a/src/diffusers/modular_pipelines/flux/modular_blocks.py b/src/diffusers/modular_pipelines/flux/modular_blocks.py
index a80bc2a5f7..bd9b2d1b40 100644
--- a/src/diffusers/modular_pipelines/flux/modular_blocks.py
+++ b/src/diffusers/modular_pipelines/flux/modular_blocks.py
@@ -360,7 +360,7 @@ class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks):
 AUTO_BLOCKS = InsertableDict(
     [
         ("text_encoder", FluxTextEncoderStep()),
-        ("image_encoder", FluxAutoVaeEncoderStep()),
+        ("vae_encoder", FluxAutoVaeEncoderStep()),
         ("denoise", FluxCoreDenoiseStep()),
         ("decode", FluxDecodeStep()),
     ]
@@ -369,7 +369,7 @@ AUTO_BLOCKS_KONTEXT = InsertableDict(
     [
         ("text_encoder", FluxTextEncoderStep()),
-        ("image_encoder", FluxKontextAutoVaeEncoderStep()),
+        ("vae_encoder", FluxKontextAutoVaeEncoderStep()),
         ("denoise", FluxKontextCoreDenoiseStep()),
         ("decode", FluxDecodeStep()),
     ]
diff --git a/src/diffusers/modular_pipelines/mellon_node_utils.py b/src/diffusers/modular_pipelines/mellon_node_utils.py
index a405aebee2..4f142a453f 100644
--- a/src/diffusers/modular_pipelines/mellon_node_utils.py
+++ b/src/diffusers/modular_pipelines/mellon_node_utils.py
@@ -4,315 +4,31 @@
 import os

 # Simple typed wrapper for parameter overrides
 from dataclasses import asdict, dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Union

-from huggingface_hub import create_repo, hf_hub_download
+from huggingface_hub import create_repo, hf_hub_download, upload_folder
 from huggingface_hub.utils import (
     EntryNotFoundError,
     HfHubHTTPError,
     RepositoryNotFoundError,
     RevisionNotFoundError,
-    validate_hf_hub_args,
 )

-from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, PushToHubMixin, extract_commit_hash
-from .modular_pipeline import ModularPipelineBlocks
+from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT


 logger = logging.getLogger(__name__)

-SUPPORTED_NODE_TYPES = {"controlnet", "vae_encoder", "denoise", "text_encoder", "decoder"}
-
-
-# Mellon Input Parameters (runtime parameters, not models)
-MELLON_INPUT_PARAMS = {
-    # controlnet
-    "control_image": {
-        "label": "Control Image",
-        "type": "image",
-        "display": "input",
-    },
-    "controlnet_conditioning_scale": {
-        "label": "Scale",
-        "type": "float",
-        "default": 0.5,
-        "min": 0,
-        "max": 1,
-    },
-    "control_guidance_end": {
-        "label": "End",
-        "type": "float",
-        "default": 1.0,
-        "min": 0,
-        "max": 1,
-    },
-    "control_guidance_start": {
-        "label": "Start",
-        "type": "float",
-        "default": 0.0,
-        "min": 0,
-        "max": 1,
-    },
-    "controlnet": {
-        "label": "Controlnet",
-        "type": "custom_controlnet",
- "display": "input", - }, - "embeddings": { - "label": "Text Embeddings", - "display": "input", - "type": "embeddings", - }, - "image": { - "label": "Image", - "type": "image", - "display": "input", - }, - "negative_prompt": { - "label": "Negative Prompt", - "type": "string", - "default": "", - "display": "textarea", - }, - "prompt": { - "label": "Prompt", - "type": "string", - "default": "", - "display": "textarea", - }, - "guidance_scale": { - "label": "Guidance Scale", - "type": "float", - "display": "slider", - "default": 5, - "min": 1.0, - "max": 30.0, - "step": 0.1, - }, - "height": { - "label": "Height", - "type": "int", - "default": 1024, - "min": 64, - "step": 8, - }, - "image_latents": { - "label": "Image Latents", - "type": "latents", - "display": "input", - "onChange": {False: ["height", "width"], True: ["strength"]}, - }, - "latents": { - "label": "Latents", - "type": "latents", - "display": "input", - }, - "num_inference_steps": { - "label": "Steps", - "type": "int", - "display": "slider", - "default": 25, - "min": 1, - "max": 100, - }, - "seed": { - "label": "Seed", - "type": "int", - "display": "random", - "default": 0, - "min": 0, - "max": 4294967295, - }, - "strength": { - "label": "Strength", - "type": "float", - "default": 0.5, - "min": 0.0, - "max": 1.0, - "step": 0.01, - }, - "width": { - "label": "Width", - "type": "int", - "default": 1024, - "min": 64, - "step": 8, - }, - "ip_adapter": { - "label": "IP Adapter", - "type": "custom_ip_adapter", - "display": "input", - }, -} - -# Mellon Model Parameters (diffusers_auto_model types) -MELLON_MODEL_PARAMS = { - "scheduler": { - "label": "Scheduler", - "display": "input", - "type": "diffusers_auto_model", - }, - "text_encoders": { - "label": "Text Encoders", - "type": "diffusers_auto_models", - "display": "input", - }, - "unet": { - "label": "Unet", - "display": "input", - "type": "diffusers_auto_model", - "onSignal": { - "action": "signal", - "target": "guider", - }, - }, - "guider": { - "label": "Guider", - "display": "input", - "type": "custom_guider", - "onChange": {False: ["guidance_scale"], True: []}, - }, - "vae": { - "label": "VAE", - "display": "input", - "type": "diffusers_auto_model", - }, - "controlnet": { - "label": "Controlnet Model", - "type": "diffusers_auto_model", - "display": "input", - }, -} - -# Mellon Output Parameters (display = "output") -MELLON_OUTPUT_PARAMS = { - "embeddings": { - "label": "Text Embeddings", - "display": "output", - "type": "embeddings", - }, - "images": { - "label": "Images", - "type": "image", - "display": "output", - }, - "image_latents": { - "label": "Image Latents", - "type": "latents", - "display": "output", - }, - "latents": { - "label": "Latents", - "type": "latents", - "display": "output", - }, - "latents_preview": { - "label": "Latents Preview", - "display": "output", - "type": "latent", - }, - "controlnet_out": { - "label": "Controlnet", - "display": "output", - "type": "controlnet", - }, -} - - -# Default param selections per supported node_type -# from MELLON_INPUT_PARAMS / MELLON_MODEL_PARAMS / MELLON_OUTPUT_PARAMS. 
-NODE_TYPE_PARAMS_MAP = { - "controlnet": { - "inputs": [ - "control_image", - "controlnet_conditioning_scale", - "control_guidance_start", - "control_guidance_end", - "height", - "width", - ], - "model_inputs": [ - "controlnet", - "vae", - ], - "outputs": [ - "controlnet", - ], - "block_names": ["controlnet_vae_encoder"], - }, - "denoise": { - "inputs": [ - "embeddings", - "width", - "height", - "seed", - "num_inference_steps", - "guidance_scale", - "image_latents", - "strength", - # custom adapters coming in as inputs - "controlnet", - # ip_adapter is optional and custom; include if available - "ip_adapter", - ], - "model_inputs": [ - "unet", - "guider", - "scheduler", - ], - "outputs": [ - "latents", - "latents_preview", - ], - "block_names": ["denoise"], - }, - "vae_encoder": { - "inputs": [ - "image", - "width", - "height", - ], - "model_inputs": [ - "vae", - ], - "outputs": [ - "image_latents", - ], - "block_names": ["vae_encoder"], - }, - "text_encoder": { - "inputs": [ - "prompt", - "negative_prompt", - # optional image prompt input supported in embeddings node - "image", - ], - "model_inputs": [ - "text_encoders", - ], - "outputs": [ - "embeddings", - ], - "block_names": ["text_encoder"], - }, - "decoder": { - "inputs": [ - "latents", - ], - "model_inputs": [ - "vae", - ], - "outputs": [ - "images", - ], - "block_names": ["decode"], - }, -} - - @dataclass(frozen=True) class MellonParam: + """ + Parameter definition for Mellon nodes. + + Use factory methods for common params (e.g., MellonParam.seed()) or create custom ones with MellonParam(name="...", + label="...", type="..."). + """ + name: str label: str type: str @@ -326,122 +42,482 @@ class MellonParam: fieldOptions: Optional[Dict[str, Any]] = None onChange: Any = None onSignal: Any = None - _map_to_input: Any = None # the block input name this parameter maps to def to_dict(self) -> Dict[str, Any]: + """Convert to dict for Mellon schema, excluding None values and name.""" data = asdict(self) - return {k: v for k, v in data.items() if not k.startswith("_") and v is not None} - - -@dataclass -class MellonNodeConfig(PushToHubMixin): - """ - A MellonNodeConfig is a base class to build Mellon nodes UI with modular diffusers. - - - - This is an experimental feature and is likely to change in the future. 
- - - """ - - inputs: List[Union[str, MellonParam]] - model_inputs: List[Union[str, MellonParam]] - outputs: List[Union[str, MellonParam]] - blocks_names: list[str] - node_type: str - config_name = "mellon_config.json" - - def __post_init__(self): - if isinstance(self.inputs, list): - self.inputs = self._resolve_params_list(self.inputs, MELLON_INPUT_PARAMS) - if isinstance(self.model_inputs, list): - self.model_inputs = self._resolve_params_list(self.model_inputs, MELLON_MODEL_PARAMS) - if isinstance(self.outputs, list): - self.outputs = self._resolve_params_list(self.outputs, MELLON_OUTPUT_PARAMS) - - @staticmethod - def _resolve_params_list( - params: List[Union[str, MellonParam]], default_map: Dict[str, Dict[str, Any]] - ) -> Dict[str, Dict[str, Any]]: - def _resolve_param( - param: Union[str, MellonParam], default_params_map: Dict[str, Dict[str, Any]] - ) -> Tuple[str, Dict[str, Any]]: - if isinstance(param, str): - if param not in default_params_map: - raise ValueError(f"Unknown param '{param}', please define a `MellonParam` object instead") - return param, default_params_map[param].copy() - elif isinstance(param, MellonParam): - param_dict = param.to_dict() - param_name = param_dict.pop("name") - return param_name, param_dict - else: - raise ValueError( - f"Unknown param type '{type(param)}', please use a string or a `MellonParam` object instead" - ) - - resolved = {} - for p in params: - logger.info(f" Resolving param: {p}") - name, cfg = _resolve_param(p, default_map) - if name in resolved: - raise ValueError(f"Duplicate param '{name}'") - resolved[name] = cfg - return resolved + return {k: v for k, v in data.items() if v is not None and k != "name"} @classmethod - @validate_hf_hub_args - def load_mellon_config( + def image(cls) -> "MellonParam": + return cls(name="image", label="Image", type="image", display="input") + + @classmethod + def images(cls) -> "MellonParam": + return cls(name="images", label="Images", type="image", display="output") + + @classmethod + def control_image(cls, display: str = "input") -> "MellonParam": + return cls(name="control_image", label="Control Image", type="image", display=display) + + @classmethod + def latents(cls, display: str = "input") -> "MellonParam": + return cls(name="latents", label="Latents", type="latents", display=display) + + @classmethod + def image_latents(cls, display: str = "input") -> "MellonParam": + return cls(name="image_latents", label="Image Latents", type="latents", display=display) + + @classmethod + def image_latents_with_strength(cls) -> "MellonParam": + return cls( + name="image_latents", + label="Image Latents", + type="latents", + display="input", + onChange={"false": ["height", "width"], "true": ["strength"]}, + ) + + @classmethod + def latents_preview(cls) -> "MellonParam": + """ + `Latents Preview` is a special output parameter that is used to preview the latents in the UI. 
+ """ + return cls(name="latents_preview", label="Latents Preview", type="latent", display="output") + + @classmethod + def embeddings(cls, display: str = "output") -> "MellonParam": + return cls(name="embeddings", label="Text Embeddings", type="embeddings", display=display) + + @classmethod + def controlnet_conditioning_scale(cls, default: float = 0.5) -> "MellonParam": + return cls( + name="controlnet_conditioning_scale", + label="Controlnet Conditioning Scale", + type="float", + default=default, + min=0.0, + max=1.0, + step=0.01, + ) + + @classmethod + def control_guidance_start(cls, default: float = 0.0) -> "MellonParam": + return cls( + name="control_guidance_start", + label="Control Guidance Start", + type="float", + default=default, + min=0.0, + max=1.0, + step=0.01, + ) + + @classmethod + def control_guidance_end(cls, default: float = 1.0) -> "MellonParam": + return cls( + name="control_guidance_end", + label="Control Guidance End", + type="float", + default=default, + min=0.0, + max=1.0, + step=0.01, + ) + + @classmethod + def prompt(cls, default: str = "") -> "MellonParam": + return cls(name="prompt", label="Prompt", type="string", default=default, display="textarea") + + @classmethod + def negative_prompt(cls, default: str = "") -> "MellonParam": + return cls(name="negative_prompt", label="Negative Prompt", type="string", default=default, display="textarea") + + @classmethod + def strength(cls, default: float = 0.5) -> "MellonParam": + return cls(name="strength", label="Strength", type="float", default=default, min=0.0, max=1.0, step=0.01) + + @classmethod + def guidance_scale(cls, default: float = 5.0) -> "MellonParam": + return cls( + name="guidance_scale", + label="Guidance Scale", + type="float", + display="slider", + default=default, + min=1.0, + max=30.0, + step=0.1, + ) + + @classmethod + def height(cls, default: int = 1024) -> "MellonParam": + return cls(name="height", label="Height", type="int", default=default, min=64, step=8) + + @classmethod + def width(cls, default: int = 1024) -> "MellonParam": + return cls(name="width", label="Width", type="int", default=default, min=64, step=8) + + @classmethod + def seed(cls, default: int = 0) -> "MellonParam": + return cls(name="seed", label="Seed", type="int", default=default, min=0, max=4294967295, display="random") + + @classmethod + def num_inference_steps(cls, default: int = 25) -> "MellonParam": + return cls( + name="num_inference_steps", label="Steps", type="int", default=default, min=1, max=100, display="slider" + ) + + @classmethod + def vae(cls) -> "MellonParam": + """ + VAE model info dict. + + Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve + the actual model. + """ + return cls(name="vae", label="VAE", type="diffusers_auto_model", display="input") + + @classmethod + def unet(cls) -> "MellonParam": + """ + Denoising model (UNet/Transformer) info dict. + + Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve + the actual model. + """ + return cls(name="unet", label="Denoise Model", type="diffusers_auto_model", display="input") + + @classmethod + def scheduler(cls) -> "MellonParam": + """ + Scheduler model info dict. + + Contains keys like 'model_id', 'repo_id' etc. Use components.get_one(model_id) to retrieve the actual + scheduler. 
+ """ + return cls(name="scheduler", label="Scheduler", type="diffusers_auto_model", display="input") + + @classmethod + def controlnet(cls) -> "MellonParam": + """ + ControlNet model info dict. + + Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve + the actual model. + """ + return cls(name="controlnet", label="ControlNet Model", type="diffusers_auto_model", display="input") + + @classmethod + def text_encoders(cls) -> "MellonParam": + """ + Dict of text encoder model info dicts. + + Structure: { + 'text_encoder': {'model_id': ..., 'execution_device': ..., ...}, 'tokenizer': {'model_id': ..., ...}, + 'repo_id': '...' + } Use components.get_one(model_id) to retrieve each model. + """ + return cls(name="text_encoders", label="Text Encoders", type="diffusers_auto_models", display="input") + + @classmethod + def controlnet_bundle(cls, display: str = "input") -> "MellonParam": + """ + ControlNet bundle containing model info and processed control inputs. + + Structure: { + 'controlnet': {'model_id': ..., ...}, # controlnet model info dict 'control_image': ..., # processed + control image/embeddings 'controlnet_conditioning_scale': ..., ... # other inputs expected by denoise + blocks + } + + Output from Controlnet node, input to Denoise node. + """ + return cls(name="controlnet_bundle", label="ControlNet", type="custom_controlnet", display=display) + + @classmethod + def ip_adapter(cls) -> "MellonParam": + return cls(name="ip_adapter", label="IP Adapter", type="custom_ip_adapter", display="input") + + @classmethod + def guider(cls) -> "MellonParam": + return cls( + name="guider", + label="Guider", + type="custom_guider", + display="input", + onChange={False: ["guidance_scale"], True: []}, + ) + + @classmethod + def doc(cls) -> "MellonParam": + return cls(name="doc", label="Doc", type="string", display="output") + + +def mark_required(label: str, marker: str = " *") -> str: + """Add required marker to label if not already present.""" + if label.endswith(marker): + return label + return f"{label}{marker}" + + +def node_spec_to_mellon_dict(node_spec: Dict[str, Any], node_type: str) -> Dict[str, Any]: + """ + Convert a node spec dict into Mellon format. + + A node spec is how we define a Mellon diffusers node in code. This function converts it into the `params` map + format that Mellon UI expects. + + The `params` map is a dict where keys are parameter names and values are UI configuration: + ```python + {"seed": {"label": "Seed", "type": "int", "default": 0}} + ``` + + For Modular Mellon nodes, we need to distinguish: + - `inputs`: Pipeline inputs (e.g., seed, prompt, image) + - `model_inputs`: Model components (e.g., unet, vae, scheduler) + - `outputs`: Node outputs (e.g., latents, images) + + The node spec also includes: + - `required_inputs` / `required_model_inputs`: Which params are required (marked with *) + - `block_name`: The modular pipeline block this node corresponds to on backend + + We provide factory methods for common parameters (e.g., `MellonParam.seed()`, `MellonParam.unet()`) so you don't + have to manually specify all the UI configuration. + + Args: + node_spec: Dict with `inputs`, `model_inputs`, `outputs` (lists of MellonParam), + plus `required_inputs`, `required_model_inputs`, `block_name`. 
+ node_type: The node type string (e.g., "denoise", "controlnet") + + Returns: + Dict with: + - `params`: Flat dict of all params in Mellon UI format + - `input_names`: List of input parameter names + - `model_input_names`: List of model input parameter names + - `output_names`: List of output parameter names + - `block_name`: The backend block name + - `node_type`: The node type + + Example: + ```python + node_spec = { + "inputs": [MellonParam.seed(), MellonParam.prompt()], + "model_inputs": [MellonParam.unet()], + "outputs": [MellonParam.latents(display="output")], + "required_inputs": ["prompt"], + "required_model_inputs": ["unet"], + "block_name": "denoise", + } + + result = node_spec_to_mellon_dict(node_spec, "denoise") + # Returns: + # { + # "params": { + # "seed": {"label": "Seed", "type": "int", "default": 0}, + # "prompt": {"label": "Prompt *", "type": "string", "default": ""}, # * marks required + # "unet": {"label": "Denoise Model *", "type": "diffusers_auto_model", "display": "input"}, + # "latents": {"label": "Latents", "type": "latents", "display": "output"}, + # }, + # "input_names": ["seed", "prompt"], + # "model_input_names": ["unet"], + # "output_names": ["latents"], + # "block_name": "denoise", + # "node_type": "denoise", + # } + ``` + """ + params = {} + input_names = [] + model_input_names = [] + output_names = [] + + required_inputs = node_spec.get("required_inputs", []) + required_model_inputs = node_spec.get("required_model_inputs", []) + + # Process inputs + for p in node_spec.get("inputs", []): + param_dict = p.to_dict() + if p.name in required_inputs: + param_dict["label"] = mark_required(param_dict["label"]) + params[p.name] = param_dict + input_names.append(p.name) + + # Process model_inputs + for p in node_spec.get("model_inputs", []): + param_dict = p.to_dict() + if p.name in required_model_inputs: + param_dict["label"] = mark_required(param_dict["label"]) + params[p.name] = param_dict + model_input_names.append(p.name) + + # Process outputs + for p in node_spec.get("outputs", []): + params[p.name] = p.to_dict() + output_names.append(p.name) + + return { + "params": params, + "input_names": input_names, + "model_input_names": model_input_names, + "output_names": output_names, + "block_name": node_spec.get("block_name"), + "node_type": node_type, + } + + +class MellonPipelineConfig: + """ + Configuration for an entire Mellon pipeline containing multiple nodes. + + Accepts node specs as dicts with inputs/model_inputs/outputs lists of MellonParam, converts them to Mellon-ready + format, and handles save/load to Hub. 
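+    Node specs are converted once, at construction time, via node_spec_to_mellon_dict, so `node_params` always
+    holds the Mellon-ready format.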
+ + Example: + ```python + config = MellonPipelineConfig( + node_specs={ + "denoise": { + "inputs": [MellonParam.seed(), MellonParam.prompt()], + "model_inputs": [MellonParam.unet()], + "outputs": [MellonParam.latents(display="output")], + "required_inputs": ["prompt"], + "required_model_inputs": ["unet"], + "block_name": "denoise", + }, + "decoder": { + "inputs": [MellonParam.latents(display="input")], + "outputs": [MellonParam.images()], + "block_name": "decoder", + }, + }, + label="My Pipeline", + default_repo="user/my-pipeline", + default_dtype="float16", + ) + + # Access Mellon format dict + denoise = config.node_params["denoise"] + input_names = denoise["input_names"] + params = denoise["params"] + + # Save to Hub + config.save("./my_config", push_to_hub=True, repo_id="user/my-pipeline") + + # Load from Hub + loaded = MellonPipelineConfig.load("user/my-pipeline") + ``` + """ + + config_name = "mellon_pipeline_config.json" + + def __init__( + self, + node_specs: Dict[str, Optional[Dict[str, Any]]], + label: str = "", + default_repo: str = "", + default_dtype: str = "", + ): + """ + Args: + node_specs: Dict mapping node_type to node spec or None. + Node spec has: inputs, model_inputs, outputs, required_inputs, required_model_inputs, + block_name (all optional) + label: Human-readable label for the pipeline + default_repo: Default HuggingFace repo for this pipeline + default_dtype: Default dtype (e.g., "float16", "bfloat16") + """ + # Convert all node specs to Mellon format immediately + self.node_params = {} + for node_type, spec in node_specs.items(): + if spec is None: + self.node_params[node_type] = None + else: + self.node_params[node_type] = node_spec_to_mellon_dict(spec, node_type) + + self.label = label + self.default_repo = default_repo + self.default_dtype = default_dtype + + def __repr__(self) -> str: + node_types = list(self.node_params.keys()) + return f"MellonPipelineConfig(label={self.label!r}, default_repo={self.default_repo!r}, default_dtype={self.default_dtype!r}, node_params={node_types})" + + def to_dict(self) -> Dict[str, Any]: + """Convert to a JSON-serializable dictionary.""" + return { + "label": self.label, + "default_repo": self.default_repo, + "default_dtype": self.default_dtype, + "node_params": self.node_params, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "MellonPipelineConfig": + """ + Create from a dictionary (loaded from JSON). + + Note: The mellon_params are already in Mellon format when loading from JSON. 
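+        Values under "node_params" are stored as-is; node_spec_to_mellon_dict is not re-applied.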
+ """ + instance = cls.__new__(cls) + instance.node_params = data.get("node_params", {}) + instance.label = data.get("label", "") + instance.default_repo = data.get("default_repo", "") + instance.default_dtype = data.get("default_dtype", "") + return instance + + def to_json_string(self) -> str: + """Serialize to JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=False) + "\n" + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """Save to a JSON file.""" + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + @classmethod + def from_json_file(cls, json_file_path: Union[str, os.PathLike]) -> "MellonPipelineConfig": + """Load from a JSON file.""" + with open(json_file_path, "r", encoding="utf-8") as reader: + data = json.load(reader) + return cls.from_dict(data) + + def save(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """Save the pipeline config to a directory.""" + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + output_path = os.path.join(save_directory, self.config_name) + self.to_json_file(output_path) + logger.info(f"Pipeline config saved to {output_path}") + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + private = kwargs.pop("private", None) + create_pr = kwargs.pop("create_pr", False) + token = kwargs.pop("token", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + subfolder = kwargs.pop("subfolder", None) + + upload_folder( + repo_id=repo_id, + folder_path=save_directory, + token=token, + commit_message=commit_message or "Upload MellonPipelineConfig", + create_pr=create_pr, + path_in_repo=subfolder, + ) + logger.info(f"Pipeline config pushed to hub: {repo_id}") + + @classmethod + def load( cls, pretrained_model_name_or_path: Union[str, os.PathLike], - return_unused_kwargs=False, - return_commit_hash=False, **kwargs, - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - r""" - Load a model or scheduler configuration. - - Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with - [`~ConfigMixin.save_config`]. - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(`bool`, *optional*, defaults to `False`): - Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. 
If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - return_unused_kwargs (`bool`, *optional*, defaults to `False): - Whether unused keyword arguments of the config are returned. - return_commit_hash (`bool`, *optional*, defaults to `False): - Whether the `commit_hash` of the loaded configuration are returned. - - Returns: - `dict`: - A dictionary of all the parameters stored in a JSON configuration file. - - """ + ) -> "MellonPipelineConfig": + """Load a pipeline config from a local path or Hugging Face Hub.""" cache_dir = kwargs.pop("cache_dir", None) local_dir = kwargs.pop("local_dir", None) local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto") @@ -450,27 +526,18 @@ class MellonNodeConfig(PushToHubMixin): token = kwargs.pop("token", None) local_files_only = kwargs.pop("local_files_only", False) revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) pretrained_model_name_or_path = str(pretrained_model_name_or_path) - if cls.config_name is None: - raise ValueError( - "`self.config_name` is not defined. Note that one should not load a config from " - "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`" - ) if os.path.isfile(pretrained_model_name_or_path): config_file = pretrained_model_name_or_path elif os.path.isdir(pretrained_model_name_or_path): - if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)): - # Load from a PyTorch checkpoint - config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) - else: - raise EnvironmentError( - f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}." 
- ) + config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) + if not os.path.isfile(config_file): + raise EnvironmentError(f"No file named {cls.config_name} found in {pretrained_model_name_or_path}") else: try: - # Load from URL or cache if already cached config_file = hf_hub_download( pretrained_model_name_or_path, filename=cls.config_name, @@ -480,6 +547,7 @@ class MellonNodeConfig(PushToHubMixin): local_files_only=local_files_only, token=token, revision=revision, + subfolder=subfolder, local_dir=local_dir, local_dir_use_symlinks=local_dir_use_symlinks, ) @@ -519,245 +587,8 @@ class MellonNodeConfig(PushToHubMixin): f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " f"containing a {cls.config_name} file" ) + try: - with open(config_file, "r", encoding="utf-8") as reader: - text = reader.read() - config_dict = json.loads(text) - - commit_hash = extract_commit_hash(config_file) + return cls.from_json_file(config_file) except (json.JSONDecodeError, UnicodeDecodeError): - raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.") - - if not (return_unused_kwargs or return_commit_hash): - return config_dict - - outputs = (config_dict,) - - if return_unused_kwargs: - outputs += (kwargs,) - - if return_commit_hash: - outputs += (commit_hash,) - - return outputs - - def save_mellon_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): - """ - Save the Mellon node definition to a JSON file. - - Args: - save_directory (`str` or `os.PathLike`): - Directory where the configuration JSON file is saved (will be created if it does not exist). - push_to_hub (`bool`, *optional*, defaults to `False`): - Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the - repository you want to push to with `repo_id` (will default to the name of `save_directory` in your - namespace). - kwargs (`Dict[str, Any]`, *optional*): - Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. - """ - if os.path.isfile(save_directory): - raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") - - os.makedirs(save_directory, exist_ok=True) - - # If we save using the predefined names, we can load using `from_config` - output_config_file = os.path.join(save_directory, self.config_name) - - self.to_json_file(output_config_file) - logger.info(f"Mellon node definition saved in {output_config_file}") - - if push_to_hub: - commit_message = kwargs.pop("commit_message", None) - private = kwargs.pop("private", None) - create_pr = kwargs.pop("create_pr", False) - token = kwargs.pop("token", None) - repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) - repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id - subfolder = kwargs.pop("subfolder", None) - - self._upload_folder( - save_directory, - repo_id, - token=token, - commit_message=commit_message, - create_pr=create_pr, - subfolder=subfolder, - ) - - def to_json_file(self, json_file_path: Union[str, os.PathLike]): - """ - Save the Mellon schema dictionary to a JSON file. - - Args: - json_file_path (`str` or `os.PathLike`): - Path to the JSON file to save a configuration instance's parameters. 
- """ - with open(json_file_path, "w", encoding="utf-8") as writer: - writer.write(self.to_json_string()) - - def to_json_string(self) -> str: - """ - Serializes this instance to a JSON string of the Mellon schema dict. - - Args: - Returns: - `str`: String containing all the attributes that make up this configuration instance in JSON format. - """ - - mellon_dict = self.to_mellon_dict() - return json.dumps(mellon_dict, indent=2, sort_keys=True) + "\n" - - def to_mellon_dict(self) -> Dict[str, Any]: - """Return a JSON-serializable dict focusing on the Mellon schema fields only. - - params is a single flat dict composed as: {**inputs, **model_inputs, **outputs}. - """ - # inputs/model_inputs/outputs are already normalized dicts - merged_params = {} - merged_params.update(self.inputs or {}) - merged_params.update(self.model_inputs or {}) - merged_params.update(self.outputs or {}) - - return { - "node_type": self.node_type, - "blocks_names": self.blocks_names, - "params": merged_params, - } - - @classmethod - def from_mellon_dict(cls, mellon_dict: Dict[str, Any]) -> "MellonNodeConfig": - """Create a config from a Mellon schema dict produced by to_mellon_dict(). - - Splits the flat params dict back into inputs/model_inputs/outputs using the known key spaces from - MELLON_INPUT_PARAMS / MELLON_MODEL_PARAMS / MELLON_OUTPUT_PARAMS. Unknown keys are treated as inputs by - default. - """ - flat_params = mellon_dict.get("params", {}) - - inputs: Dict[str, Any] = {} - model_inputs: Dict[str, Any] = {} - outputs: Dict[str, Any] = {} - - for param_name, param_dict in flat_params.items(): - if param_dict.get("display", "") == "output": - outputs[param_name] = param_dict - elif param_dict.get("type", "") in ("diffusers_auto_model", "diffusers_auto_models"): - model_inputs[param_name] = param_dict - else: - inputs[param_name] = param_dict - - return cls( - inputs=inputs, - model_inputs=model_inputs, - outputs=outputs, - blocks_names=mellon_dict.get("blocks_names", []), - node_type=mellon_dict.get("node_type"), - ) - - # YiYi Notes: not used yet - @classmethod - def from_blocks(cls, blocks: ModularPipelineBlocks, node_type: str) -> "MellonNodeConfig": - """ - Create an instance from a ModularPipeline object. If a preset exists in NODE_TYPE_PARAMS_MAP for the node_type, - use it; otherwise fall back to deriving lists from the pipeline's expected inputs/components/outputs. 
- """ - if node_type not in NODE_TYPE_PARAMS_MAP: - raise ValueError(f"Node type {node_type} not supported") - - blocks_names = list(blocks.sub_blocks.keys()) - - default_node_config = NODE_TYPE_PARAMS_MAP[node_type] - inputs_list: List[Union[str, MellonParam]] = default_node_config.get("inputs", []) - model_inputs_list: List[Union[str, MellonParam]] = default_node_config.get("model_inputs", []) - outputs_list: List[Union[str, MellonParam]] = default_node_config.get("outputs", []) - - for required_input_name in blocks.required_inputs: - if required_input_name not in inputs_list: - inputs_list.append( - MellonParam( - name=required_input_name, label=required_input_name, type=required_input_name, display="input" - ) - ) - - for component_spec in blocks.expected_components: - if component_spec.name not in model_inputs_list: - model_inputs_list.append( - MellonParam( - name=component_spec.name, - label=component_spec.name, - type="diffusers_auto_model", - display="input", - ) - ) - - return cls( - inputs=inputs_list, - model_inputs=model_inputs_list, - outputs=outputs_list, - blocks_names=blocks_names, - node_type=node_type, - ) - - -# Minimal modular registry for Mellon node configs -class ModularMellonNodeRegistry: - """Registry mapping (pipeline class, blocks_name) -> list of MellonNodeConfig.""" - - def __init__(self): - self._registry = {} - self._initialized = False - - def register(self, pipeline_cls: type, node_params: Dict[str, MellonNodeConfig]): - if not self._initialized: - _initialize_registry(self) - self._registry[pipeline_cls] = node_params - - def get(self, pipeline_cls: type) -> MellonNodeConfig: - if not self._initialized: - _initialize_registry(self) - return self._registry.get(pipeline_cls, None) - - def get_all(self) -> Dict[type, Dict[str, MellonNodeConfig]]: - if not self._initialized: - _initialize_registry(self) - return self._registry - - -def _register_preset_node_types( - pipeline_cls, params_map: Dict[str, Dict[str, Any]], registry: ModularMellonNodeRegistry -): - """Register all node-type presets for a given pipeline class from a params map.""" - node_configs = {} - for node_type, spec in params_map.items(): - node_config = MellonNodeConfig( - inputs=spec.get("inputs", []), - model_inputs=spec.get("model_inputs", []), - outputs=spec.get("outputs", []), - blocks_names=spec.get("block_names", []), - node_type=node_type, - ) - node_configs[node_type] = node_config - registry.register(pipeline_cls, node_configs) - - -def _initialize_registry(registry: ModularMellonNodeRegistry): - """Initialize the registry and register all available pipeline configs.""" - print("Initializing registry") - - registry._initialized = True - - try: - from .qwenimage.modular_pipeline import QwenImageModularPipeline - from .qwenimage.node_utils import QwenImage_NODE_TYPES_PARAMS_MAP - - _register_preset_node_types(QwenImageModularPipeline, QwenImage_NODE_TYPES_PARAMS_MAP, registry) - except Exception: - raise Exception("Failed to register QwenImageModularPipeline") - - try: - from .stable_diffusion_xl.modular_pipeline import StableDiffusionXLModularPipeline - from .stable_diffusion_xl.node_utils import SDXL_NODE_TYPES_PARAMS_MAP - - _register_preset_node_types(StableDiffusionXLModularPipeline, SDXL_NODE_TYPES_PARAMS_MAP, registry) - except Exception: - raise Exception("Failed to register StableDiffusionXLModularPipeline") + raise EnvironmentError(f"The config file at '{config_file}' is not a valid JSON file.") diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py 
b/src/diffusers/modular_pipelines/modular_pipeline.py index 17c0117bff..c5fa4cf992 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -501,15 +501,19 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): @property def input_names(self) -> List[str]: - return [input_param.name for input_param in self.inputs] + return [input_param.name for input_param in self.inputs if input_param.name is not None] @property def intermediate_output_names(self) -> List[str]: - return [output_param.name for output_param in self.intermediate_outputs] + return [output_param.name for output_param in self.intermediate_outputs if output_param.name is not None] @property def output_names(self) -> List[str]: - return [output_param.name for output_param in self.outputs] + return [output_param.name for output_param in self.outputs if output_param.name is not None] + + @property + def component_names(self) -> List[str]: + return [component.name for component in self.expected_components] @property def doc(self): @@ -1525,10 +1529,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin): if blocks is None: if modular_config_dict is not None: blocks_class_name = modular_config_dict.get("_blocks_class_name") - elif config_dict is not None: - blocks_class_name = self.get_default_blocks_name(config_dict) else: - blocks_class_name = None + blocks_class_name = self.get_default_blocks_name(config_dict) if blocks_class_name is not None: diffusers_module = importlib.import_module("diffusers") blocks_class = getattr(diffusers_module, blocks_class_name) @@ -1625,7 +1627,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin): return None, config_dict except EnvironmentError as e: - logger.debug(f" model_index.json not found in the repo: {e}") + raise EnvironmentError( + f"Failed to load config from '{pretrained_model_name_or_path}'. " + f"Could not find or load 'modular_model_index.json' or 'model_index.json'." 
+            ) from e

         return None, None

@@ -2550,7 +2555,11 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
                 kwargs_type = expected_input_param.kwargs_type
                 if name in passed_kwargs:
                     state.set(name, passed_kwargs.pop(name), kwargs_type)
-                elif name not in state.values:
+                elif kwargs_type is not None and kwargs_type in passed_kwargs:
+                    kwargs_dict = passed_kwargs.pop(kwargs_type)
+                    for k, v in kwargs_dict.items():
+                        state.set(k, v, kwargs_type)
+                elif name is not None and name not in state.values:
                     state.set(name, default, kwargs_type)

         # Warn about unexpected inputs
diff --git a/src/diffusers/modular_pipelines/qwenimage/decoders.py b/src/diffusers/modular_pipelines/qwenimage/decoders.py
index 26417162de..6e145f1855 100644
--- a/src/diffusers/modular_pipelines/qwenimage/decoders.py
+++ b/src/diffusers/modular_pipelines/qwenimage/decoders.py
@@ -30,6 +30,47 @@ from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
 logger = logging.get_logger(__name__)


+class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
+    model_name = "qwenimage"
+
+    @property
+    def description(self) -> str:
+        return "Step that unpacks the latents from a 3D tensor (batch_size, sequence_length, channels) into a 5D tensor (batch_size, channels, 1, height, width)"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        components = [
+            ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
+        ]
+
+        return components
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(name="height", required=True),
+            InputParam(name="width", required=True),
+            InputParam(
+                name="latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The latents to decode; can be generated in the denoise step",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        vae_scale_factor = components.vae_scale_factor
+        block_state.latents = components.pachifier.unpack_latents(
+            block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor
+        )
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
 class QwenImageDecoderStep(ModularPipelineBlocks):
     model_name = "qwenimage"

@@ -41,7 +82,6 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
     def expected_components(self) -> List[ComponentSpec]:
         components = [
             ComponentSpec("vae", AutoencoderKLQwenImage),
-            ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
         ]

         return components
@@ -49,8 +89,6 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam(name="height", required=True),
-            InputParam(name="width", required=True),
             InputParam(
                 name="latents",
                 required=True,
@@ -74,10 +112,12 @@
         block_state = self.get_block_state(state)

         # YiYi Notes: remove support for output_type = "latents', we can just skip decode/encode step in modular
-        vae_scale_factor = components.vae_scale_factor
-        block_state.latents = components.pachifier.unpack_latents(
-            block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor
-        )
+        if block_state.latents.ndim == 4:
+            block_state.latents = block_state.latents.unsqueeze(dim=1)
+        elif block_state.latents.ndim != 5:
+            raise ValueError(
+                f"Expected latents to be a 4D or 5D tensor but got: {block_state.latents.shape}.
Please make sure the latents are unpacked before decode step." + ) block_state.latents = block_state.latents.to(components.vae.dtype) latents_mean = ( diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py index 55a7ae328f..dcce0cab5d 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py @@ -26,7 +26,12 @@ from .before_denoise import ( QwenImageSetTimestepsStep, QwenImageSetTimestepsWithStrengthStep, ) -from .decoders import QwenImageDecoderStep, QwenImageInpaintProcessImagesOutputStep, QwenImageProcessImagesOutputStep +from .decoders import ( + QwenImageAfterDenoiseStep, + QwenImageDecoderStep, + QwenImageInpaintProcessImagesOutputStep, + QwenImageProcessImagesOutputStep, +) from .denoise import ( QwenImageControlNetDenoiseStep, QwenImageDenoiseStep, @@ -92,6 +97,7 @@ TEXT2IMAGE_BLOCKS = InsertableDict( ("set_timesteps", QwenImageSetTimestepsStep()), ("prepare_rope_inputs", QwenImageRoPEInputsStep()), ("denoise", QwenImageDenoiseStep()), + ("after_denoise", QwenImageAfterDenoiseStep()), ("decode", QwenImageDecodeStep()), ] ) @@ -205,6 +211,7 @@ INPAINT_BLOCKS = InsertableDict( ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()), ("prepare_rope_inputs", QwenImageRoPEInputsStep()), ("denoise", QwenImageInpaintDenoiseStep()), + ("after_denoise", QwenImageAfterDenoiseStep()), ("decode", QwenImageInpaintDecodeStep()), ] ) @@ -264,6 +271,7 @@ IMAGE2IMAGE_BLOCKS = InsertableDict( ("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()), ("prepare_rope_inputs", QwenImageRoPEInputsStep()), ("denoise", QwenImageDenoiseStep()), + ("after_denoise", QwenImageAfterDenoiseStep()), ("decode", QwenImageDecodeStep()), ] ) @@ -529,8 +537,16 @@ class QwenImageCoreDenoiseStep(SequentialPipelineBlocks): QwenImageAutoBeforeDenoiseStep, QwenImageOptionalControlNetBeforeDenoiseStep, QwenImageAutoDenoiseStep, + QwenImageAfterDenoiseStep, + ] + block_names = [ + "input", + "controlnet_input", + "before_denoise", + "controlnet_before_denoise", + "denoise", + "after_denoise", ] - block_names = ["input", "controlnet_input", "before_denoise", "controlnet_before_denoise", "denoise"] @property def description(self): @@ -653,6 +669,7 @@ EDIT_BLOCKS = InsertableDict( ("set_timesteps", QwenImageSetTimestepsStep()), ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()), ("denoise", QwenImageEditDenoiseStep()), + ("after_denoise", QwenImageAfterDenoiseStep()), ("decode", QwenImageDecodeStep()), ] ) @@ -702,6 +719,7 @@ EDIT_INPAINT_BLOCKS = InsertableDict( ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()), ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()), ("denoise", QwenImageEditInpaintDenoiseStep()), + ("after_denoise", QwenImageAfterDenoiseStep()), ("decode", QwenImageInpaintDecodeStep()), ] ) @@ -841,8 +859,9 @@ class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks): QwenImageEditAutoInputStep, QwenImageEditAutoBeforeDenoiseStep, QwenImageEditAutoDenoiseStep, + QwenImageAfterDenoiseStep, ] - block_names = ["input", "before_denoise", "denoise"] + block_names = ["input", "before_denoise", "denoise", "after_denoise"] @property def description(self): @@ -954,6 +973,7 @@ EDIT_PLUS_BLOCKS = InsertableDict( ("set_timesteps", QwenImageSetTimestepsStep()), ("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()), ("denoise", QwenImageEditDenoiseStep()), + ("after_denoise", QwenImageAfterDenoiseStep()), 
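+        # after_denoise unpacks the denoised latents to (batch, channels, 1, height, width) for the decode step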
("decode", QwenImageDecodeStep()), ] ) @@ -1037,8 +1057,9 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks): QwenImageEditPlusAutoInputStep, QwenImageEditPlusAutoBeforeDenoiseStep, QwenImageEditAutoDenoiseStep, + QwenImageAfterDenoiseStep, ] - block_names = ["input", "before_denoise", "denoise"] + block_names = ["input", "before_denoise", "denoise", "after_denoise"] @property def description(self): diff --git a/src/diffusers/modular_pipelines/qwenimage/node_utils.py b/src/diffusers/modular_pipelines/qwenimage/node_utils.py deleted file mode 100644 index 3230ece68a..0000000000 --- a/src/diffusers/modular_pipelines/qwenimage/node_utils.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# mellon nodes -QwenImage_NODE_TYPES_PARAMS_MAP = { - "controlnet": { - "inputs": [ - "control_image", - "controlnet_conditioning_scale", - "control_guidance_start", - "control_guidance_end", - "height", - "width", - ], - "model_inputs": [ - "controlnet", - "vae", - ], - "outputs": [ - "controlnet_out", - ], - "block_names": ["controlnet_vae_encoder"], - }, - "denoise": { - "inputs": [ - "embeddings", - "width", - "height", - "seed", - "num_inference_steps", - "guidance_scale", - "image_latents", - "strength", - "controlnet", - ], - "model_inputs": [ - "unet", - "guider", - "scheduler", - ], - "outputs": [ - "latents", - "latents_preview", - ], - "block_names": ["denoise"], - }, - "vae_encoder": { - "inputs": [ - "image", - "width", - "height", - ], - "model_inputs": [ - "vae", - ], - "outputs": [ - "image_latents", - ], - }, - "text_encoder": { - "inputs": [ - "prompt", - "negative_prompt", - ], - "model_inputs": [ - "text_encoders", - ], - "outputs": [ - "embeddings", - ], - }, - "decoder": { - "inputs": [ - "latents", - ], - "model_inputs": [ - "vae", - ], - "outputs": [ - "images", - ], - }, -} diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/node_utils.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/node_utils.py deleted file mode 100644 index 3e788bf947..0000000000 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/node_utils.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -SDXL_NODE_TYPES_PARAMS_MAP = { - "controlnet": { - "inputs": [ - "control_image", - "controlnet_conditioning_scale", - "control_guidance_start", - "control_guidance_end", - "height", - "width", - ], - "model_inputs": [ - "controlnet", - ], - "outputs": [ - "controlnet_out", - ], - "block_names": [None], - }, - "denoise": { - "inputs": [ - "embeddings", - "width", - "height", - "seed", - "num_inference_steps", - "guidance_scale", - "image_latents", - "strength", - # custom adapters coming in as inputs - "controlnet", - # ip_adapter is optional and custom; include if available - "ip_adapter", - ], - "model_inputs": [ - "unet", - "guider", - "scheduler", - ], - "outputs": [ - "latents", - "latents_preview", - ], - "block_names": ["denoise"], - }, - "vae_encoder": { - "inputs": [ - "image", - "width", - "height", - ], - "model_inputs": [ - "vae", - ], - "outputs": [ - "image_latents", - ], - "block_names": ["vae_encoder"], - }, - "text_encoder": { - "inputs": [ - "prompt", - "negative_prompt", - ], - "model_inputs": [ - "text_encoders", - ], - "outputs": [ - "embeddings", - ], - "block_names": ["text_encoder"], - }, - "decoder": { - "inputs": [ - "latents", - ], - "model_inputs": [ - "vae", - ], - "outputs": [ - "images", - ], - "block_names": ["decode"], - }, -} diff --git a/src/diffusers/modular_pipelines/z_image/denoise.py b/src/diffusers/modular_pipelines/z_image/denoise.py index ec815f77ad..3d5a00a9df 100644 --- a/src/diffusers/modular_pipelines/z_image/denoise.py +++ b/src/diffusers/modular_pipelines/z_image/denoise.py @@ -129,6 +129,10 @@ class ZImageLoopDenoiser(ModularPipelineBlocks): type_hint=int, description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", ), + InputParam( + kwargs_type="denoiser_input_fields", + description="conditional model inputs for the denoiser: e.g. 
prompt_embeds, negative_prompt_embeds, etc.", + ), ] guider_input_names = [] uncond_guider_input_names = [] diff --git a/src/diffusers/modular_pipelines/z_image/modular_blocks.py b/src/diffusers/modular_pipelines/z_image/modular_blocks.py index a7c520301a..a54baeccaf 100644 --- a/src/diffusers/modular_pipelines/z_image/modular_blocks.py +++ b/src/diffusers/modular_pipelines/z_image/modular_blocks.py @@ -119,7 +119,7 @@ class ZImageAutoDenoiseStep(AutoPipelineBlocks): class ZImageAutoVaeImageEncoderStep(AutoPipelineBlocks): block_classes = [ZImageVaeImageEncoderStep] - block_names = ["vae_image_encoder"] + block_names = ["vae_encoder"] block_trigger_inputs = ["image"] @property @@ -137,7 +137,7 @@ class ZImageAutoBlocks(SequentialPipelineBlocks): ZImageAutoDenoiseStep, ZImageVaeDecoderStep, ] - block_names = ["text_encoder", "vae_image_encoder", "denoise", "decode"] + block_names = ["text_encoder", "vae_encoder", "denoise", "decode"] @property def description(self) -> str: @@ -162,7 +162,7 @@ TEXT2IMAGE_BLOCKS = InsertableDict( IMAGE2IMAGE_BLOCKS = InsertableDict( [ ("text_encoder", ZImageTextEncoderStep), - ("vae_image_encoder", ZImageVaeImageEncoderStep), + ("vae_encoder", ZImageVaeImageEncoderStep), ("input", ZImageTextInputStep), ("additional_inputs", ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"])), ("prepare_latents", ZImagePrepareLatentsStep), @@ -178,7 +178,7 @@ IMAGE2IMAGE_BLOCKS = InsertableDict( AUTO_BLOCKS = InsertableDict( [ ("text_encoder", ZImageTextEncoderStep), - ("vae_image_encoder", ZImageAutoVaeImageEncoderStep), + ("vae_encoder", ZImageAutoVaeImageEncoderStep), ("denoise", ZImageAutoDenoiseStep), ("decode", ZImageVaeDecoderStep), ] From 262ce19bff6b19e38aed3519fc9eb2d90d24f87a Mon Sep 17 00:00:00 2001 From: MatrixTeam-AI Date: Sat, 20 Dec 2025 07:10:40 +0800 Subject: [PATCH 3/3] Feature: Add Mambo-G Guidance as Guider (#12862) * Feature: Add Mambo-G Guidance to Qwen-Image Pipeline * change to guider implementation * fix copied code residual * Update src/diffusers/guiders/magnitude_aware_guidance.py * Apply style fixes --------- Co-authored-by: Pscgylotti Co-authored-by: YiYi Xu Co-authored-by: github-actions[bot] --- src/diffusers/guiders/__init__.py | 1 + .../guiders/magnitude_aware_guidance.py | 159 ++++++++++++++++++ 2 files changed, 160 insertions(+) create mode 100644 src/diffusers/guiders/magnitude_aware_guidance.py diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index 4e53c373c4..58ad0c211b 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -25,6 +25,7 @@ if is_torch_available(): from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance from .frequency_decoupled_guidance import FrequencyDecoupledGuidance from .guider_utils import BaseGuidance + from .magnitude_aware_guidance import MagnitudeAwareGuidance from .perturbed_attention_guidance import PerturbedAttentionGuidance from .skip_layer_guidance import SkipLayerGuidance from .smoothed_energy_guidance import SmoothedEnergyGuidance diff --git a/src/diffusers/guiders/magnitude_aware_guidance.py b/src/diffusers/guiders/magnitude_aware_guidance.py new file mode 100644 index 0000000000..b81cf0d3a1 --- /dev/null +++ b/src/diffusers/guiders/magnitude_aware_guidance.py @@ -0,0 +1,159 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch + +from ..configuration_utils import register_to_config +from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg + + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +class MagnitudeAwareGuidance(BaseGuidance): + """ + Magnitude-Aware Mitigation for Boosted Guidance (MAMBO-G): https://huggingface.co/papers/2508.03442 + + Args: + guidance_scale (`float`, defaults to `10.0`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + alpha (`float`, defaults to `8.0`): + The alpha parameter for the magnitude-aware guidance. Higher values cause more aggressive suppression of + the guidance scale when the magnitude of the guidance update is large. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. 
+ """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + @register_to_config + def __init__( + self, + guidance_scale: float = 10.0, + alpha: float = 8.0, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + enabled: bool = True, + ): + super().__init__(start, stop, enabled) + + self.guidance_scale = guidance_scale + self.alpha = alpha + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch(data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: + pred = None + + if not self._is_mambo_g_enabled(): + pred = pred_cond + else: + pred = mambo_guidance( + pred_cond, + pred_uncond, + self.guidance_scale, + self.alpha, + self.use_original_formulation, + ) + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond) + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_mambo_g_enabled(): + num_conditions += 1 + return num_conditions + + def _is_mambo_g_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + +def mambo_guidance( + pred_cond: torch.Tensor, + pred_uncond: torch.Tensor, + guidance_scale: float, + alpha: float = 8.0, + use_original_formulation: bool = False, +): + dim = list(range(1, len(pred_cond.shape))) + diff = pred_cond - pred_uncond + ratio = torch.norm(diff, dim=dim, keepdim=True) / torch.norm(pred_uncond, dim=dim, keepdim=True) + guidance_scale_final = ( + guidance_scale * torch.exp(-alpha * ratio) + if use_original_formulation + else 1.0 + (guidance_scale - 1.0) * torch.exp(-alpha * ratio) + ) + pred = pred_cond if use_original_formulation else pred_uncond + pred = pred + guidance_scale_final * diff + + return pred