From 87f63d424a6efb0d309ced1e67b827bd10881b7c Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Thu, 22 May 2025 11:50:36 +0200
Subject: [PATCH] modular node!
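This PR adds ModularNode, a wrapper that turns modular pipeline blocks into a
node definition for a node-based UI (currently targeting Mellon), and records
a mapping between diffusers param names and node param names.

A rough usage sketch of the intended flow (the repo id, manager class and
input names below are illustrative, not part of this patch):

    from diffusers import ComponentsManager
    from diffusers.modular_pipelines.node_utils import ModularNode

    node = ModularNode.from_pretrained("YiYiXu/modular-demo", trust_remote_code=True)
    node.save_mellon_config("mellon/node_config.json")  # export the node definition

    components = ComponentsManager()
    node.setup(components)  # attach a loader backed by the components manager

    # kwargs are keyed by the mellon param names recorded in node.config.name_mapping
    output = node.execute(prompt="a bear drinking a milkshake", steps=25)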
"num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"), - "latent_timestep": OutputParam("latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"), - "add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"), - "negative_add_time_ids": OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), - "timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), - "latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"), - "noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), - "ip_adapter_embeds": OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), - "negative_ip_adapter_embeds": OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), - "images": OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="Generated images") -} +SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA} + DEFAULT_PARAM_MAPS = { "prompt": { @@ -191,7 +173,7 @@ DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker" DEFAULT_PARAMS_GROUPS_KEYS = { "text_encoders": ["text_encoder", "tokenizer"], "ip_adapter_embeds": ["ip_adapter_embeds"], - "text_embeds": ["prompt_embeds"], + "prompt_embeddings": ["prompt_embeds"], } @@ -200,144 +182,330 @@ def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS): Get the group name for a given parameter name, if not part of a group, return None e.g. "prompt_embeds" -> "text_embeds", "text_encoder" -> "text_encoders", "prompt" -> None """ + if name is None: + return None for group_name, group_keys in group_params_keys.items(): for group_key in group_keys: if group_key in name: return group_name return None + + +class ModularNode(ConfigMixin): -class MellonNode(ConfigMixin): - - block_class = None config_name = "node_config.json" + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + trust_remote_code: Optional[bool] = None, + **kwargs, + ): + blocks = ModularPipelineMixin.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs) + return cls(blocks, **kwargs) - def __init__(self, category=DEFAULT_CATEGORY, label=None, input_params=None, intermediate_params=None, component_params=None, output_params=None): - self.blocks = self.block_class() + def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs): + self.blocks = blocks if label is None: label = self.blocks.__class__.__name__ - - expected_inputs = [inp.name for inp in self.blocks.inputs] - expected_intermediates = [inp.name for inp in self.blocks.intermediates_inputs] - expected_components = [comp.name for comp in self.blocks.expected_components] - expected_outputs = [out.name for out in self.blocks.intermediates_outputs] + # blocks param name -> mellon param name + self.name_mapping = {} - if input_params is None: - input_params ={} - for inp in self.blocks.inputs: - # create a param dict for each input e.g. 
+        inputs = self.blocks.inputs + self.blocks.intermediates_inputs
+        for inp in inputs:
+            param = kwargs.pop(inp.name, None)
+            if param:
+                # user can pass a param dict for any input, e.g. ModularNode(prompt={...})
+                input_params[inp.name] = param
+                mellon_name = param.pop("name", inp.name)
+                if mellon_name != inp.name:
+                    self.name_mapping[inp.name] = mellon_name
+                continue
+
+            if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
+                continue
+
+            if inp.name in DEFAULT_PARAM_MAPS:
+                # first check if it's in the default param map, if so, directly use that
+                param = DEFAULT_PARAM_MAPS[inp.name].copy()
+            elif get_group_name(inp.name):
+                param = get_group_name(inp.name)
+                if inp.name not in self.name_mapping:
+                    self.name_mapping[inp.name] = param
+            else:
+                # if not, use the input's own type hint if it has one, otherwise look it up in the SDXL param schema,
+                # then use the default param dict for that type,
+                # e.g. if "steps" is an "int" type, {"steps": {"type": "int", "default": 0, "min": 0}}
+                if inp.type_hint is not None:
+                    type_str = str(inp.type_hint).lower()
+                else:
+                    inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None)
+                    type_str = str(inp_spec.type_hint).lower() if inp_spec else ""
+                for type_key, type_param in DEFAULT_TYPE_MAPS.items():
+                    if type_key in type_str:
+                        param = type_param.copy()
+                        param["label"] = inp.name
+                        param["display"] = "input"
+                        break
+                else:
+                    param = inp.name
+            # add the param dict to the input_params dict
+            input_params[inp.name] = param
+
+        component_params = {}
+        for comp in self.blocks.expected_components:
+            param = kwargs.pop(comp.name, None)
+            if param:
+                component_params[comp.name] = param
+                mellon_name = param.pop("name", comp.name)
+                if mellon_name != comp.name:
+                    self.name_mapping[comp.name] = mellon_name
+                continue
+
+            to_exclude = False
+            for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS:
+                if exclude_key in comp.name:
+                    to_exclude = True
+                    break
+            if to_exclude:
+                continue
+
+            if get_group_name(comp.name):
+                param = get_group_name(comp.name)
+                if comp.name not in self.name_mapping:
+                    self.name_mapping[comp.name] = param
+            elif comp.name in DEFAULT_MODEL_KEYS:
+                param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"}
+            else:
+                param = comp.name
+            # add the param dict to the component_params dict
+            component_params[comp.name] = param
+
+        output_params = {}
+        if isinstance(self.blocks, SequentialPipelineBlocks):
+            last_block_name = list(self.blocks.blocks.keys())[-1]
+            outputs = self.blocks.blocks[last_block_name].intermediates_outputs
+        else:
+            outputs = self.blocks.intermediates_outputs
+
+        for out in outputs:
+            param = kwargs.pop(out.name, None)
+            if param:
+                output_params[out.name] = param
+                mellon_name = param.pop("name", out.name)
+                if mellon_name != out.name:
+                    self.name_mapping[out.name] = mellon_name
+                continue
+
+            if out.name in DEFAULT_PARAM_MAPS:
+                param = DEFAULT_PARAM_MAPS[out.name].copy()
+                param["display"] = "output"
+            else:
+                group_name = get_group_name(out.name)
+                if group_name:
+                    param = group_name
+                    if out.name not in self.name_mapping:
+                        self.name_mapping[out.name] = param
+                else:
+                    param = out.name
+            # add the param dict to the output_params dict
+            output_params[out.name] = param
+
+        if len(kwargs) > 0:
+            logger.warning(f"Unused kwargs: {kwargs}")
+
         register_dict = {
             "category": category,
             "label": label,
             "input_params": input_params,
-            "intermediate_params": intermediate_params,
             "component_params": component_params,
             "output_params": output_params,
+            "name_mapping": self.name_mapping,
         }
         self.register_to_config(**register_dict)
+
+    def setup(self, components, collection=None):
+        self.blocks.setup_loader(component_manager=components, collection=collection)
+        self._components_manager = components
+
+    @property
+    def mellon_config(self):
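+        """
+        The node definition dict consumed by Mellon, shaped like (illustrative):
+        {"label": ..., "category": ..., "params": {mellon_name: {"label": ..., "type": ..., "display": ...}, ...}}
+        """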
+        return self._convert_to_mellon_config()
+
+    def _convert_to_mellon_config(self):
+        node = {}
+        node["label"] = self.config.label
+        node["category"] = self.config.category
+
+        node_param = {}
+        for inp_name, inp_param in self.config.input_params.items():
+            if inp_name in self.name_mapping:
+                mellon_name = self.name_mapping[inp_name]
+            else:
+                mellon_name = inp_name
+            # a str param stands for a "custom" type of the same name
+            if isinstance(inp_param, str):
+                param = {
+                    "label": inp_param,
+                    "type": inp_param,
+                    "display": "input",
+                }
+            else:
+                param = inp_param
+
+            if mellon_name not in node_param:
+                node_param[mellon_name] = param
+            else:
+                logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}")
+
+        for comp_name, comp_param in self.config.component_params.items():
+            if comp_name in self.name_mapping:
+                mellon_name = self.name_mapping[comp_name]
+            else:
+                mellon_name = comp_name
+            if isinstance(comp_param, str):
+                param = {
+                    "label": comp_param,
+                    "type": comp_param,
+                    "display": "input",
+                }
+            else:
+                param = comp_param
+
+            if mellon_name not in node_param:
+                node_param[mellon_name] = param
+            else:
+                logger.debug(f"Component param {mellon_name} already exists in node_param, skipping {comp_name}")
+
+        for out_name, out_param in self.config.output_params.items():
+            if out_name in self.name_mapping:
+                mellon_name = self.name_mapping[out_name]
+            else:
+                mellon_name = out_name
+            if isinstance(out_param, str):
+                param = {
+                    "label": out_param,
+                    "type": out_param,
+                    "display": "output",
+                }
+            else:
+                param = out_param
+
+            if mellon_name not in node_param:
+                node_param[mellon_name] = param
+            else:
+                logger.debug(f"Output param {mellon_name} already exists in node_param, skipping {out_name}")
+
+        node["params"] = node_param
+        return node
+
+    def save_mellon_config(self, file_path):
+        """
+        Save the Mellon configuration to a JSON file.
+
+        Args:
+            file_path (str or Path): Path where the JSON file will be saved
+
+        Returns:
+            Path: Path to the saved config file
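+
+        Example (path is illustrative):
+
+            >>> node.save_mellon_config("mellon/node_config.json")
+            >>> config = ModularNode.load_mellon_config("mellon/node_config.json")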
+        """
+        file_path = Path(file_path)
+
+        # Create directory if it doesn't exist
+        os.makedirs(file_path.parent, exist_ok=True)
+
+        # Create a combined dictionary with the node definition and the name mapping
+        config = {
+            "module": self.mellon_config,
+            "name_mapping": self.name_mapping,
+        }
+
+        # Save the config to file
+        with open(file_path, 'w', encoding='utf-8') as f:
+            json.dump(config, f, indent=2)
+
+        logger.info(f"Mellon config and name mapping saved to {file_path}")
+
+        return file_path
+
+    @classmethod
+    def load_mellon_config(cls, file_path):
+        """
+        Load a Mellon configuration from a JSON file.
+
+        Args:
+            file_path (str or Path): Path to the JSON file containing the Mellon config
+
+        Returns:
+            dict: The loaded combined configuration containing 'module' and 'name_mapping'
+        """
+        file_path = Path(file_path)
+
+        if not file_path.exists():
+            raise FileNotFoundError(f"Config file not found: {file_path}")
+
+        with open(file_path, 'r', encoding='utf-8') as f:
+            config = json.load(f)
+
+        logger.info(f"Mellon config loaded from {file_path}")
+
+        return config
+
+    def process_inputs(self, **kwargs):
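+        """
+        Split the raw node kwargs (keyed by mellon param names) into the components to
+        fetch from the components manager and the runtime inputs for the blocks, plus
+        the list of output names to return. Grouped params arrive as dicts keyed by the
+        original diffusers names, e.g. (illustrative)
+        kwargs = {"text_embeddings": {"prompt_embeds": ...}, "steps": 25}.
+        """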
+        params_components = {}
+        for comp_name, comp_param in self.config.component_params.items():
+            logger.debug(f"component: {comp_name}")
+            mellon_comp_name = self.name_mapping.get(comp_name, comp_name)
+            if mellon_comp_name in kwargs:
+                # grouped params arrive as a dict keyed by the original diffusers name
+                if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]:
+                    comp = kwargs[mellon_comp_name].pop(comp_name)
+                else:
+                    comp = kwargs.pop(mellon_comp_name)
+                if comp:
+                    # a model param from the node is a dict carrying the manager's model_id
+                    params_components[comp_name] = self._components_manager.get_one(comp["model_id"])
+
+        params_run = {}
+        for inp_name, inp_param in self.config.input_params.items():
+            logger.debug(f"input: {inp_name}")
+            mellon_inp_name = self.name_mapping.get(inp_name, inp_name)
+            if mellon_inp_name in kwargs:
+                if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]:
+                    inp = kwargs[mellon_inp_name].pop(inp_name)
+                else:
+                    inp = kwargs.pop(mellon_inp_name)
+                if inp is not None:
+                    params_run[inp_name] = inp
+
+        return_output_names = list(self.config.output_params.keys())
+
+        return params_components, params_run, return_output_names
+
+    def execute(self, **kwargs):
+        params_components, params_run, return_output_names = self.process_inputs(**kwargs)
+
+        self.blocks.loader.update(**params_components)
+        output = self.blocks.run(**params_run, output=return_output_names)
+        return output