diff --git a/src/diffusers/modular_pipelines/mellon_node_utils.py b/src/diffusers/modular_pipelines/mellon_node_utils.py index aae0427196..ba1d1542b2 100644 --- a/src/diffusers/modular_pipelines/mellon_node_utils.py +++ b/src/diffusers/modular_pipelines/mellon_node_utils.py @@ -4,7 +4,7 @@ import os # Simple typed wrapper for parameter overrides from dataclasses import asdict, dataclass -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Union, List from huggingface_hub import create_repo, hf_hub_download, upload_folder from huggingface_hub.utils import ( @@ -42,35 +42,36 @@ class MellonParam: fieldOptions: Optional[Dict[str, Any]] = None onChange: Any = None onSignal: Any = None + required_block_params: Optional[Union[str, List[str]]] = None def to_dict(self) -> Dict[str, Any]: """Convert to dict for Mellon schema, excluding None values and name.""" data = asdict(self) - return {k: v for k, v in data.items() if v is not None and k != "name"} + return {k: v for k, v in data.items() if v is not None and k not in ("name", "required_block_params")} @classmethod def image(cls) -> "MellonParam": - return cls(name="image", label="Image", type="image", display="input") + return cls(name="image", label="Image", type="image", display="input", required_block_params=["image"]) @classmethod def images(cls) -> "MellonParam": - return cls(name="images", label="Images", type="image", display="output") + return cls(name="images", label="Images", type="image", display="output", required_block_params=["images"]) @classmethod def control_image(cls, display: str = "input") -> "MellonParam": - return cls(name="control_image", label="Control Image", type="image", display=display) + return cls(name="control_image", label="Control Image", type="image", display=display, required_block_params=["control_image"]) @classmethod def latents(cls, display: str = "input") -> "MellonParam": - return cls(name="latents", label="Latents", type="latents", display=display) + return cls(name="latents", label="Latents", type="latents", display=display, required_block_params=["latents"]) @classmethod def image_latents(cls, display: str = "input") -> "MellonParam": - return cls(name="image_latents", label="Image Latents", type="latents", display=display) + return cls(name="image_latents", label="Image Latents", type="latents", display=display, required_block_params=["image_latents"]) @classmethod def first_frame_latents(cls, display: str = "input") -> "MellonParam": - return cls(name="first_frame_latents", label="First Frame Latents", type="latents", display=display) + return cls(name="first_frame_latents", label="First Frame Latents", type="latents", display=display, required_block_params=["first_frame_latents"]) @classmethod def image_latents_with_strength(cls) -> "MellonParam": @@ -80,6 +81,7 @@ class MellonParam: type="latents", display="input", onChange={"false": ["height", "width"], "true": ["strength"]}, + required_block_params=["image_latents", "strength"], ) @classmethod @@ -95,7 +97,7 @@ class MellonParam: @classmethod def image_embeds(cls, display: str = "output") -> "MellonParam": - return cls(name="image_embeds", label="Image Embeddings", type="image_embeds", display=display) + return cls(name="image_embeds", label="Image Embeddings", type="image_embeds", display=display, required_block_params=["image_embeds"]) @classmethod def controlnet_conditioning_scale(cls, default: float = 0.5) -> "MellonParam": @@ -107,6 +109,7 @@ class MellonParam: min=0.0, max=1.0, step=0.01, + required_block_params=["controlnet_conditioning_scale"], ) @classmethod @@ -119,6 +122,7 @@ class MellonParam: min=0.0, max=1.0, step=0.01, + required_block_params=["control_guidance_start"], ) @classmethod @@ -131,19 +135,20 @@ class MellonParam: min=0.0, max=1.0, step=0.01, + required_block_params=["control_guidance_end"], ) @classmethod def prompt(cls, default: str = "") -> "MellonParam": - return cls(name="prompt", label="Prompt", type="string", default=default, display="textarea") + return cls(name="prompt", label="Prompt", type="string", default=default, display="textarea", required_block_params=["prompt"]) @classmethod def negative_prompt(cls, default: str = "") -> "MellonParam": - return cls(name="negative_prompt", label="Negative Prompt", type="string", default=default, display="textarea") + return cls(name="negative_prompt", label="Negative Prompt", type="string", default=default, display="textarea", required_block_params=["negative_prompt"]) @classmethod def strength(cls, default: float = 0.5) -> "MellonParam": - return cls(name="strength", label="Strength", type="float", default=default, min=0.0, max=1.0, step=0.01) + return cls(name="strength", label="Strength", type="float", default=default, min=0.0, max=1.0, step=0.01, required_block_params=["strength"]) @classmethod def guidance_scale(cls, default: float = 5.0) -> "MellonParam": @@ -160,33 +165,33 @@ class MellonParam: @classmethod def height(cls, default: int = 1024) -> "MellonParam": - return cls(name="height", label="Height", type="int", default=default, min=64, step=8) + return cls(name="height", label="Height", type="int", default=default, min=64, step=8, required_block_params=["height"]) @classmethod def width(cls, default: int = 1024) -> "MellonParam": - return cls(name="width", label="Width", type="int", default=default, min=64, step=8) + return cls(name="width", label="Width", type="int", default=default, min=64, step=8, required_block_params=["width"]) @classmethod def seed(cls, default: int = 0) -> "MellonParam": - return cls(name="seed", label="Seed", type="int", default=default, min=0, max=4294967295, display="random") + return cls(name="seed", label="Seed", type="int", default=default, min=0, max=4294967295, display="random", required_block_params=["generator"]) @classmethod def num_inference_steps(cls, default: int = 25) -> "MellonParam": return cls( - name="num_inference_steps", label="Steps", type="int", default=default, min=1, max=100, display="slider" + name="num_inference_steps", label="Steps", type="int", default=default, min=1, max=100, display="slider", required_block_params=["num_inference_steps"] ) @classmethod def num_frames(cls, default: int = 81) -> "MellonParam": - return cls(name="num_frames", label="Frames", type="int", default=default, min=1, max=480, display="slider") + return cls(name="num_frames", label="Frames", type="int", default=default, min=1, max=480, display="slider", required_block_params=["num_frames"]) @classmethod def layers(cls, default: int = 4) -> "MellonParam": - return cls(name="layers", label="Layers", type="int", default=default, min=1, max=10, display="slider") + return cls(name="layers", label="Layers", type="int", default=default, min=1, max=10, display="slider", required_block_params=["layers"]) @classmethod def videos(cls) -> "MellonParam": - return cls(name="videos", label="Videos", type="video", display="output") + return cls(name="videos", label="Videos", type="video", display="output", required_block_params=["videos"]) @classmethod def vae(cls) -> "MellonParam": @@ -196,7 +201,7 @@ class MellonParam: Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve the actual model. """ - return cls(name="vae", label="VAE", type="diffusers_auto_model", display="input") + return cls(name="vae", label="VAE", type="diffusers_auto_model", display="input", required_block_params=["vae"]) @classmethod def image_encoder(cls) -> "MellonParam": @@ -206,7 +211,7 @@ class MellonParam: Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve the actual model. """ - return cls(name="image_encoder", label="Image Encoder", type="diffusers_auto_model", display="input") + return cls(name="image_encoder", label="Image Encoder", type="diffusers_auto_model", display="input", required_block_params=["image_encoder"]) @classmethod def unet(cls) -> "MellonParam": @@ -236,7 +241,7 @@ class MellonParam: Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve the actual model. """ - return cls(name="controlnet", label="ControlNet Model", type="diffusers_auto_model", display="input") + return cls(name="controlnet", label="ControlNet Model", type="diffusers_auto_model", display="input", required_block_params=["controlnet"]) @classmethod def text_encoders(cls) -> "MellonParam": @@ -248,7 +253,7 @@ class MellonParam: 'repo_id': '...' } Use components.get_one(model_id) to retrieve each model. """ - return cls(name="text_encoders", label="Text Encoders", type="diffusers_auto_models", display="input") + return cls(name="text_encoders", label="Text Encoders", type="diffusers_auto_models", display="input", required_block_params=["text_encoder"]) @classmethod def controlnet_bundle(cls, display: str = "input") -> "MellonParam": @@ -263,7 +268,7 @@ class MellonParam: Output from Controlnet node, input to Denoise node. """ - return cls(name="controlnet_bundle", label="ControlNet", type="custom_controlnet", display=display) + return cls(name="controlnet_bundle", label="ControlNet", type="custom_controlnet", display=display, required_block_params="controlnet_image") @classmethod def ip_adapter(cls) -> "MellonParam": @@ -284,6 +289,86 @@ class MellonParam: return cls(name="doc", label="Doc", type="string", display="output") +DEFAULT_NODE_SPECS = { + "controlnet": None, + "denoise": { + "inputs": [ + MellonParam.embeddings(display="input"), + MellonParam.width(), + MellonParam.height(), + MellonParam.seed(), + MellonParam.num_inference_steps(), + MellonParam.guidance_scale(), + MellonParam.strength(), + MellonParam.image_latents_with_strength(), + MellonParam.image_latents(), + MellonParam.first_frame_latents(), + MellonParam.controlnet_bundle(display="input"), + ], + "model_inputs": [ + MellonParam.unet(), + MellonParam.guider(), + MellonParam.scheduler(), + ], + "outputs": [ + MellonParam.latents(display="output"), + MellonParam.latents_preview(), + MellonParam.doc(), + ], + "required_inputs": ["embeddings"], + "required_model_inputs": ["unet", "scheduler"], + "block_name": "denoise", + }, + "vae_encoder": { + "inputs": [ + MellonParam.image(), + ], + "model_inputs": [ + MellonParam.vae(), + ], + "outputs": [ + MellonParam.image_latents(display="output"), + MellonParam.doc(), + ], + "required_inputs": ["image"], + "required_model_inputs": ["vae"], + "block_name": "vae_encoder", + }, + "text_encoder": { + "inputs": [ + MellonParam.prompt(), + MellonParam.negative_prompt(), + ], + "model_inputs": [ + MellonParam.text_encoders(), + ], + "outputs": [ + MellonParam.embeddings(display="output"), + MellonParam.doc(), + ], + "required_inputs": ["prompt"], + "required_model_inputs": ["text_encoders"], + "block_name": "text_encoder", + }, + "decoder": { + "inputs": [ + MellonParam.latents(display="input"), + ], + "model_inputs": [ + MellonParam.vae(), + ], + "outputs": [ + MellonParam.images(), + MellonParam.videos(), + MellonParam.doc(), + ], + "required_inputs": ["latents"], + "required_model_inputs": ["vae"], + "block_name": "decode", + }, +} + + def mark_required(label: str, marker: str = " *") -> str: """Add required marker to label if not already present.""" if label.endswith(marker): @@ -458,20 +543,38 @@ class MellonPipelineConfig: default_dtype: Default dtype (e.g., "float16", "bfloat16") """ # Convert all node specs to Mellon format immediately - self.node_params = {} - for node_type, spec in node_specs.items(): - if spec is None: - self.node_params[node_type] = None - else: - self.node_params[node_type] = node_spec_to_mellon_dict(spec, node_type) + self.node_specs = node_specs self.label = label self.default_repo = default_repo self.default_dtype = default_dtype + @property + def node_params(self) -> Dict[str, Any]: + """Lazily compute node_params from node_specs.""" + params = {} + for node_type, spec in self.node_specs.items(): + if spec is None: + params[node_type] = None + else: + params[node_type] = node_spec_to_mellon_dict(spec, node_type) + return params + def __repr__(self) -> str: - node_types = list(self.node_params.keys()) - return f"MellonPipelineConfig(label={self.label!r}, default_repo={self.default_repo!r}, default_dtype={self.default_dtype!r}, node_params={node_types})" + lines = [f"MellonPipelineConfig(label={self.label!r}, default_repo={self.default_repo!r}, default_dtype={self.default_dtype!r})"] + for node_type, spec in self.node_specs.items(): + if spec is None: + lines.append(f" {node_type}: None") + else: + inputs = [p.name for p in spec.get("inputs", [])] + model_inputs = [p.name for p in spec.get("model_inputs", [])] + outputs = [p.name for p in spec.get("outputs", [])] + lines.append(f" {node_type}:") + lines.append(f" inputs: {inputs}") + lines.append(f" model_inputs: {model_inputs}") + lines.append(f" outputs: {outputs}") + return "\n".join(lines) + def to_dict(self) -> Dict[str, Any]: """Convert to a JSON-serializable dictionary.""" @@ -622,3 +725,67 @@ class MellonPipelineConfig: return cls.from_json_file(config_file) except (json.JSONDecodeError, UnicodeDecodeError): raise EnvironmentError(f"The config file at '{config_file}' is not a valid JSON file.") + + @classmethod + def from_blocks( + cls, + blocks: "ModularPipelineBlocks", + template: Dict[str, Optional[Dict[str, Any]]] = None, + label: str = "", + default_repo: str = "", + default_dtype: str = "bfloat16", + ) -> "MellonPipelineConfig": + """ + Create MellonPipelineConfig by matching template against actual pipeline blocks. + """ + if template is None: + template = DEFAULT_NODE_SPECS + + sub_block_map = dict(blocks.sub_blocks) + + def filter_spec_for_block(template_spec: Dict[str, Any], block) -> Optional[Dict[str, Any]]: + """Filter template spec params based on what the block actually supports.""" + block_input_names = set(block.input_names) + block_output_names = set(block.intermediate_output_names) + block_component_names = set(block.component_names) + + filtered_inputs = [p for p in template_spec.get("inputs", []) if p.required_block_params is None or all(name in block_input_names for name in p.required_block_params)] + filtered_model_inputs = [p for p in template_spec.get("model_inputs", []) if p.required_block_params is None or all(name in block_component_names for name in p.required_block_params)] + filtered_outputs = [p for p in template_spec.get("outputs", []) if p.required_block_params is None or all(name in block_output_names for name in p.required_block_params)] + + filtered_input_names = {p.name for p in filtered_inputs} + filtered_model_input_names = {p.name for p in filtered_model_inputs} + + filtered_required_inputs = [r for r in template_spec.get("required_inputs", []) if r in filtered_input_names] + filtered_required_model_inputs = [r for r in template_spec.get("required_model_inputs", []) if r in filtered_model_input_names] + + + return { + "inputs": filtered_inputs, + "model_inputs": filtered_model_inputs, + "outputs": filtered_outputs, + "required_inputs": filtered_required_inputs, + "required_model_inputs": filtered_required_model_inputs, + "block_name": template_spec.get("block_name"), + } + + # Build node specs + node_specs = {} + for node_type, template_spec in template.items(): + if template_spec is None: + node_specs[node_type] = None + continue + + block_name = template_spec.get("block_name") + if block_name is None or block_name not in sub_block_map: + node_specs[node_type] = None + continue + + node_specs[node_type] = filter_spec_for_block(template_spec, sub_block_map[block_name]) + + return cls( + node_specs=node_specs, + label=label or getattr(blocks, "model_name", ""), + default_repo=default_repo, + default_dtype=default_dtype, + ) \ No newline at end of file