diff --git a/src/diffusers/modular_pipelines/mellon_node_utils.py b/src/diffusers/modular_pipelines/mellon_node_utils.py index ba1d1542b2..b4e9463992 100644 --- a/src/diffusers/modular_pipelines/mellon_node_utils.py +++ b/src/diffusers/modular_pipelines/mellon_node_utils.py @@ -4,7 +4,7 @@ import os # Simple typed wrapper for parameter overrides from dataclasses import asdict, dataclass -from typing import Any, Dict, Optional, Union, List +from typing import Any, Dict, List, Optional, Union from huggingface_hub import create_repo, hf_hub_download, upload_folder from huggingface_hub.utils import ( @@ -42,7 +42,7 @@ class MellonParam: fieldOptions: Optional[Dict[str, Any]] = None onChange: Any = None onSignal: Any = None - required_block_params: Optional[Union[str, List[str]]] = None + required_block_params: Optional[Union[str, List[str]]] = None def to_dict(self) -> Dict[str, Any]: """Convert to dict for Mellon schema, excluding None values and name.""" @@ -59,7 +59,13 @@ class MellonParam: @classmethod def control_image(cls, display: str = "input") -> "MellonParam": - return cls(name="control_image", label="Control Image", type="image", display=display, required_block_params=["control_image"]) + return cls( + name="control_image", + label="Control Image", + type="image", + display=display, + required_block_params=["control_image"], + ) @classmethod def latents(cls, display: str = "input") -> "MellonParam": @@ -67,11 +73,23 @@ class MellonParam: @classmethod def image_latents(cls, display: str = "input") -> "MellonParam": - return cls(name="image_latents", label="Image Latents", type="latents", display=display, required_block_params=["image_latents"]) + return cls( + name="image_latents", + label="Image Latents", + type="latents", + display=display, + required_block_params=["image_latents"], + ) @classmethod def first_frame_latents(cls, display: str = "input") -> "MellonParam": - return cls(name="first_frame_latents", label="First Frame Latents", type="latents", display=display, required_block_params=["first_frame_latents"]) + return cls( + name="first_frame_latents", + label="First Frame Latents", + type="latents", + display=display, + required_block_params=["first_frame_latents"], + ) @classmethod def image_latents_with_strength(cls) -> "MellonParam": @@ -97,7 +115,13 @@ class MellonParam: @classmethod def image_embeds(cls, display: str = "output") -> "MellonParam": - return cls(name="image_embeds", label="Image Embeddings", type="image_embeds", display=display, required_block_params=["image_embeds"]) + return cls( + name="image_embeds", + label="Image Embeddings", + type="image_embeds", + display=display, + required_block_params=["image_embeds"], + ) @classmethod def controlnet_conditioning_scale(cls, default: float = 0.5) -> "MellonParam": @@ -140,15 +164,38 @@ class MellonParam: @classmethod def prompt(cls, default: str = "") -> "MellonParam": - return cls(name="prompt", label="Prompt", type="string", default=default, display="textarea", required_block_params=["prompt"]) + return cls( + name="prompt", + label="Prompt", + type="string", + default=default, + display="textarea", + required_block_params=["prompt"], + ) @classmethod def negative_prompt(cls, default: str = "") -> "MellonParam": - return cls(name="negative_prompt", label="Negative Prompt", type="string", default=default, display="textarea", required_block_params=["negative_prompt"]) + return cls( + name="negative_prompt", + label="Negative Prompt", + type="string", + default=default, + display="textarea", + required_block_params=["negative_prompt"], + ) @classmethod def strength(cls, default: float = 0.5) -> "MellonParam": - return cls(name="strength", label="Strength", type="float", default=default, min=0.0, max=1.0, step=0.01, required_block_params=["strength"]) + return cls( + name="strength", + label="Strength", + type="float", + default=default, + min=0.0, + max=1.0, + step=0.01, + required_block_params=["strength"], + ) @classmethod def guidance_scale(cls, default: float = 5.0) -> "MellonParam": @@ -165,29 +212,73 @@ class MellonParam: @classmethod def height(cls, default: int = 1024) -> "MellonParam": - return cls(name="height", label="Height", type="int", default=default, min=64, step=8, required_block_params=["height"]) + return cls( + name="height", + label="Height", + type="int", + default=default, + min=64, + step=8, + required_block_params=["height"], + ) @classmethod def width(cls, default: int = 1024) -> "MellonParam": - return cls(name="width", label="Width", type="int", default=default, min=64, step=8, required_block_params=["width"]) + return cls( + name="width", label="Width", type="int", default=default, min=64, step=8, required_block_params=["width"] + ) @classmethod def seed(cls, default: int = 0) -> "MellonParam": - return cls(name="seed", label="Seed", type="int", default=default, min=0, max=4294967295, display="random", required_block_params=["generator"]) + return cls( + name="seed", + label="Seed", + type="int", + default=default, + min=0, + max=4294967295, + display="random", + required_block_params=["generator"], + ) @classmethod def num_inference_steps(cls, default: int = 25) -> "MellonParam": return cls( - name="num_inference_steps", label="Steps", type="int", default=default, min=1, max=100, display="slider", required_block_params=["num_inference_steps"] + name="num_inference_steps", + label="Steps", + type="int", + default=default, + min=1, + max=100, + display="slider", + required_block_params=["num_inference_steps"], ) @classmethod def num_frames(cls, default: int = 81) -> "MellonParam": - return cls(name="num_frames", label="Frames", type="int", default=default, min=1, max=480, display="slider", required_block_params=["num_frames"]) + return cls( + name="num_frames", + label="Frames", + type="int", + default=default, + min=1, + max=480, + display="slider", + required_block_params=["num_frames"], + ) @classmethod def layers(cls, default: int = 4) -> "MellonParam": - return cls(name="layers", label="Layers", type="int", default=default, min=1, max=10, display="slider", required_block_params=["layers"]) + return cls( + name="layers", + label="Layers", + type="int", + default=default, + min=1, + max=10, + display="slider", + required_block_params=["layers"], + ) @classmethod def videos(cls) -> "MellonParam": @@ -201,7 +292,9 @@ class MellonParam: Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve the actual model. """ - return cls(name="vae", label="VAE", type="diffusers_auto_model", display="input", required_block_params=["vae"]) + return cls( + name="vae", label="VAE", type="diffusers_auto_model", display="input", required_block_params=["vae"] + ) @classmethod def image_encoder(cls) -> "MellonParam": @@ -211,7 +304,13 @@ class MellonParam: Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve the actual model. """ - return cls(name="image_encoder", label="Image Encoder", type="diffusers_auto_model", display="input", required_block_params=["image_encoder"]) + return cls( + name="image_encoder", + label="Image Encoder", + type="diffusers_auto_model", + display="input", + required_block_params=["image_encoder"], + ) @classmethod def unet(cls) -> "MellonParam": @@ -241,7 +340,13 @@ class MellonParam: Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve the actual model. """ - return cls(name="controlnet", label="ControlNet Model", type="diffusers_auto_model", display="input", required_block_params=["controlnet"]) + return cls( + name="controlnet", + label="ControlNet Model", + type="diffusers_auto_model", + display="input", + required_block_params=["controlnet"], + ) @classmethod def text_encoders(cls) -> "MellonParam": @@ -253,7 +358,13 @@ class MellonParam: 'repo_id': '...' } Use components.get_one(model_id) to retrieve each model. """ - return cls(name="text_encoders", label="Text Encoders", type="diffusers_auto_models", display="input", required_block_params=["text_encoder"]) + return cls( + name="text_encoders", + label="Text Encoders", + type="diffusers_auto_models", + display="input", + required_block_params=["text_encoder"], + ) @classmethod def controlnet_bundle(cls, display: str = "input") -> "MellonParam": @@ -268,7 +379,13 @@ class MellonParam: Output from Controlnet node, input to Denoise node. """ - return cls(name="controlnet_bundle", label="ControlNet", type="custom_controlnet", display=display, required_block_params="controlnet_image") + return cls( + name="controlnet_bundle", + label="ControlNet", + type="custom_controlnet", + display=display, + required_block_params="controlnet_image", + ) @classmethod def ip_adapter(cls) -> "MellonParam": @@ -561,7 +678,9 @@ class MellonPipelineConfig: return params def __repr__(self) -> str: - lines = [f"MellonPipelineConfig(label={self.label!r}, default_repo={self.default_repo!r}, default_dtype={self.default_dtype!r})"] + lines = [ + f"MellonPipelineConfig(label={self.label!r}, default_repo={self.default_repo!r}, default_dtype={self.default_dtype!r})" + ] for node_type, spec in self.node_specs.items(): if spec is None: lines.append(f" {node_type}: None") @@ -575,7 +694,6 @@ class MellonPipelineConfig: lines.append(f" outputs: {outputs}") return "\n".join(lines) - def to_dict(self) -> Dict[str, Any]: """Convert to a JSON-serializable dictionary.""" return { @@ -729,7 +847,7 @@ class MellonPipelineConfig: @classmethod def from_blocks( cls, - blocks: "ModularPipelineBlocks", + blocks, template: Dict[str, Optional[Dict[str, Any]]] = None, label: str = "", default_repo: str = "", @@ -740,26 +858,44 @@ class MellonPipelineConfig: """ if template is None: template = DEFAULT_NODE_SPECS - + sub_block_map = dict(blocks.sub_blocks) - + def filter_spec_for_block(template_spec: Dict[str, Any], block) -> Optional[Dict[str, Any]]: """Filter template spec params based on what the block actually supports.""" block_input_names = set(block.input_names) block_output_names = set(block.intermediate_output_names) block_component_names = set(block.component_names) - - filtered_inputs = [p for p in template_spec.get("inputs", []) if p.required_block_params is None or all(name in block_input_names for name in p.required_block_params)] - filtered_model_inputs = [p for p in template_spec.get("model_inputs", []) if p.required_block_params is None or all(name in block_component_names for name in p.required_block_params)] - filtered_outputs = [p for p in template_spec.get("outputs", []) if p.required_block_params is None or all(name in block_output_names for name in p.required_block_params)] - + + filtered_inputs = [ + p + for p in template_spec.get("inputs", []) + if p.required_block_params is None + or all(name in block_input_names for name in p.required_block_params) + ] + filtered_model_inputs = [ + p + for p in template_spec.get("model_inputs", []) + if p.required_block_params is None + or all(name in block_component_names for name in p.required_block_params) + ] + filtered_outputs = [ + p + for p in template_spec.get("outputs", []) + if p.required_block_params is None + or all(name in block_output_names for name in p.required_block_params) + ] + filtered_input_names = {p.name for p in filtered_inputs} filtered_model_input_names = {p.name for p in filtered_model_inputs} - - filtered_required_inputs = [r for r in template_spec.get("required_inputs", []) if r in filtered_input_names] - filtered_required_model_inputs = [r for r in template_spec.get("required_model_inputs", []) if r in filtered_model_input_names] - + filtered_required_inputs = [ + r for r in template_spec.get("required_inputs", []) if r in filtered_input_names + ] + filtered_required_model_inputs = [ + r for r in template_spec.get("required_model_inputs", []) if r in filtered_model_input_names + ] + return { "inputs": filtered_inputs, "model_inputs": filtered_model_inputs, @@ -768,24 +904,24 @@ class MellonPipelineConfig: "required_model_inputs": filtered_required_model_inputs, "block_name": template_spec.get("block_name"), } - + # Build node specs node_specs = {} for node_type, template_spec in template.items(): if template_spec is None: node_specs[node_type] = None continue - + block_name = template_spec.get("block_name") if block_name is None or block_name not in sub_block_map: node_specs[node_type] = None continue - + node_specs[node_type] = filter_spec_for_block(template_spec, sub_block_map[block_name]) - + return cls( node_specs=node_specs, label=label or getattr(blocks, "model_name", ""), default_repo=default_repo, default_dtype=default_dtype, - ) \ No newline at end of file + )