mirror of https://github.com/huggingface/diffusers.git

modular node!

This commit is contained in:
yiyixuxu
2025-05-22 11:50:36 +02:00
parent 29de29f02c
commit 87f63d424a
2 changed files with 305 additions and 137 deletions

View File

@@ -246,7 +246,7 @@ class InputParam:
default: Any = None
required: bool = False
description: str = ""
kwargs_type: str = None # YiYi Notes: experimenting with this, not sure if we should keep it
kwargs_type: str = None # YiYi Notes: remove this feature (maybe)
def __repr__(self):
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
@@ -258,7 +258,7 @@ class OutputParam:
name: str
type_hint: Any = None
description: str = ""
kwargs_type: str = None
kwargs_type: str = None # YiYi notes: remove this feature (maybe)
def __repr__(self):
return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"

View File

@@ -1,7 +1,10 @@
from ..configuration_utils import ConfigMixin
from .modular_pipeline import SequentialPipelineBlocks
from .modular_pipeline import SequentialPipelineBlocks, ModularPipelineMixin
from .modular_pipeline_utils import InputParam, OutputParam
from ..image_processor import PipelineImageInput
from pathlib import Path
import json
import os
from typing import Union, List, Optional, Tuple
import torch
@@ -77,29 +80,8 @@ SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
"images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images")
}
SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = {
"prompt_embeds": OutputParam("prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"),
"negative_prompt_embeds": OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"),
"pooled_prompt_embeds": OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"),
"negative_pooled_prompt_embeds": OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"),
"batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"),
"dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
"image_latents": OutputParam("image_latents", type_hint=torch.Tensor, description="Latents representing reference image"),
"mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"),
"masked_image_latents": OutputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"),
"crops_coords": OutputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
"timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"),
"num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"),
"latent_timestep": OutputParam("latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"),
"add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"),
"negative_add_time_ids": OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"),
"timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
"latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"),
"noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
"ip_adapter_embeds": OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"),
"negative_ip_adapter_embeds": OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"),
"images": OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="Generated images")
}
SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA}
DEFAULT_PARAM_MAPS = {
"prompt": {
@@ -191,7 +173,7 @@ DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"
DEFAULT_PARAMS_GROUPS_KEYS = {
"text_encoders": ["text_encoder", "tokenizer"],
"ip_adapter_embeds": ["ip_adapter_embeds"],
"text_embeds": ["prompt_embeds"],
"prompt_embeddings": ["prompt_embeds"],
}
@@ -200,144 +182,330 @@ def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
Get the group name for a given parameter name, if not part of a group, return None
e.g. "prompt_embeds" -> "text_embeds", "text_encoder" -> "text_encoders", "prompt" -> None
"""
if name is None:
return None
for group_name, group_keys in group_params_keys.items():
for group_key in group_keys:
if group_key in name:
return group_name
return None
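With the group key renamed to "prompt_embeddings" in this commit, the substring lookup resolves roughly as in this sketch (assuming the DEFAULT_PARAMS_GROUPS_KEYS defaults above):

    get_group_name("prompt_embeds")                  # -> "prompt_embeddings"
    get_group_name("negative_pooled_prompt_embeds")  # -> "prompt_embeddings"
    get_group_name("text_encoder_2")                 # -> "text_encoders"
    get_group_name("prompt")                         # -> None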
class ModularNode(ConfigMixin):
class MellonNode(ConfigMixin):
block_class = None
config_name = "node_config.json"
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
trust_remote_code: Optional[bool] = None,
**kwargs,
):
blocks = ModularPipelineMixin.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
return cls(blocks, **kwargs)
def __init__(self, category=DEFAULT_CATEGORY, label=None, input_params=None, intermediate_params=None, component_params=None, output_params=None):
self.blocks = self.block_class()
def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):
self.blocks = blocks
if label is None:
label = self.blocks.__class__.__name__
expected_inputs = [inp.name for inp in self.blocks.inputs]
expected_intermediates = [inp.name for inp in self.blocks.intermediates_inputs]
expected_components = [comp.name for comp in self.blocks.expected_components]
expected_outputs = [out.name for out in self.blocks.intermediates_outputs]
# blocks param name -> mellon param name
self.name_mapping = {}
if input_params is None:
input_params ={}
for inp in self.blocks.inputs:
# create a param dict for each input e.g. for prompt, param = {"prompt": {"label": "Prompt", "type": "string", "default": "a bear sitting in a chair drinking a milkshake", "display": "textarea"} }
param = {}
if inp.name in DEFAULT_PARAM_MAPS:
# first check if it's in the default param map, if so, directly use that
param[inp.name] = DEFAULT_PARAM_MAPS[inp.name]
elif inp.required:
group_name = get_group_name(inp.name)
if group_name:
param = group_name
else:
# if not, check if it's in the SDXL input schema, if so,
# 1. use the type hint to determine the type
# 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}}
inp_spec = SDXL_INPUTS_SCHEMA.get(inp.name, None)
if inp_spec:
type_str = str(inp_spec.type_hint).lower()
for type_key, type_param in DEFAULT_TYPE_MAPS.items():
if type_key in type_str:
param[inp.name] = type_param
param[inp.name]["display"] = "input"
break
else:
param = inp.name
# add the param dict to the inp_params dict
if param:
input_params[inp.name] = param
if intermediate_params is None:
intermediate_params = {}
for inp in self.blocks.intermediates_inputs:
param = {}
if inp.name in DEFAULT_PARAM_MAPS:
param[inp.name] = DEFAULT_PARAM_MAPS[inp.name]
elif inp.required:
group_name = get_group_name(inp.name)
if group_name:
param = group_name
else:
inp_spec = SDXL_INTERMEDIATE_INPUTS_SCHEMA.get(inp.name, None)
if inp_spec:
type_str = str(inp_spec.type_hint).lower()
for type_key, type_param in DEFAULT_TYPE_MAPS.items():
if type_key in type_str:
param[inp.name] = type_param
param[inp.name]["display"] = "input"
break
else:
param = inp.name
# add the param dict to the intermediate_params dict
if param:
intermediate_params[inp.name] = param
if component_params is None:
component_params = {}
for comp in self.blocks.expected_components:
to_exclude = False
for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS:
if exclude_key in comp.name:
to_exclude = True
input_params = {}
# pass or create a default param dict for each input
# e.g. for prompt,
# prompt = {
# "name": "text_input", # the name of the input in node defination, could be different from the input name in diffusers
# "label": "Prompt",
# "type": "string",
# "default": "a bear sitting in a chair drinking a milkshake",
# "display": "textarea"}
# if type is not specified, it'll be a "custom" param of its own type
# e.g. you can pass ModularNode(scheduler = {name :"scheduler"})
# it will get this spec in the node definition {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
# name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}}
inputs = self.blocks.inputs + self.blocks.intermediates_inputs
for inp in inputs:
param = kwargs.pop(inp.name, None)
if param:
# user can pass a param dict for all inputs, e.g. ModularNode(prompt = {...})
input_params[inp.name] = param
mellon_name = param.pop("name", inp.name)
if mellon_name != inp.name:
self.name_mapping[inp.name] = mellon_name
continue
if not inp.name in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
continue
if inp.name in DEFAULT_PARAM_MAPS:
# first check if it's in the default param map, if so, directly use that
param = DEFAULT_PARAM_MAPS[inp.name].copy()
elif get_group_name(inp.name):
param = get_group_name(inp.name)
if inp.name not in self.name_mapping:
self.name_mapping[inp.name] = param
else:
# if not, check if it's in the SDXL input schema, if so,
# 1. use the type hint to determine the type
# 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}}
if inp.type_hint is not None:
type_str = str(inp.type_hint).lower()
else:
inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None)
type_str = str(inp_spec.type_hint).lower() if inp_spec else ""
for type_key, type_param in DEFAULT_TYPE_MAPS.items():
if type_key in type_str:
param = type_param.copy()
param["label"] = inp.name
param["display"] = "input"
break
if to_exclude:
continue
param = {}
group_name = get_group_name(comp.name)
else:
param = inp.name
# add the param dict to the inp_params dict
input_params[inp.name] = param
component_params = {}
for comp in self.blocks.expected_components:
param = kwargs.pop(comp.name, None)
if param:
component_params[comp.name] = param
mellon_name = param.pop("name", comp.name)
if mellon_name != comp.name:
self.name_mapping[comp.name] = mellon_name
continue
to_exclude = False
for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS:
if exclude_key in comp.name:
to_exclude = True
break
if to_exclude:
continue
if get_group_name(comp.name):
param = get_group_name(comp.name)
if comp.name not in self.name_mapping:
self.name_mapping[comp.name] = param
elif comp.name in DEFAULT_MODEL_KEYS:
param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"}
else:
param = comp.name
# add the param dict to the model_params dict
component_params[comp.name] = param
output_params = {}
if isinstance(self.blocks, SequentialPipelineBlocks):
last_block_name = list(self.blocks.blocks.keys())[-1]
outputs = self.blocks.blocks[last_block_name].intermediates_outputs
else:
outputs = self.blocks.intermediates_outputs
for out in outputs:
param = kwargs.pop(out.name, None)
if param:
output_params[out.name] = param
mellon_name = param.pop("name", out.name)
if mellon_name != out.name:
self.name_mapping[out.name] = mellon_name
continue
if out.name in DEFAULT_PARAM_MAPS:
param = DEFAULT_PARAM_MAPS[out.name].copy()
param["display"] = "output"
else:
group_name = get_group_name(out.name)
if group_name:
param = group_name
elif comp.name in DEFAULT_MODEL_KEYS:
param[comp.name] = {
"label": comp.name,
"type": "diffusers_auto_model",
"display": "input",
}
if out.name not in self.name_mapping:
self.name_mapping[out.name] = param
else:
param = comp.name
# add the param dict to the model_params dict
if param:
component_params[comp.name] = param
if output_params is None:
output_params = {}
if isinstance(self.blocks, SequentialPipelineBlocks):
last_block_name = list(self.blocks.blocks.keys())[-1]
outputs = self.blocks.blocks[last_block_name].intermediates_outputs
else:
outputs = self.blocks.intermediates_outputs
param = out.name
# add the param dict to the outputs dict
output_params[out.name] = param
for out in outputs:
param = {}
if out.name in DEFAULT_PARAM_MAPS:
param[out.name] = DEFAULT_PARAM_MAPS[out.name]
param[out.name]["display"] = "output"
else:
group_name = get_group_name(out.name)
if group_name:
param = group_name
else:
param = out.name
# add the param dict to the outputs dict
if param:
output_params[out.name] = param
if len(kwargs) > 0:
logger.warning(f"Unused kwargs: {kwargs}")
register_dict = {
"category": category,
"label": label,
"input_params": input_params,
"intermediate_params": intermediate_params,
"component_params": component_params,
"output_params": output_params,
"name_mapping": self.name_mapping,
}
self.register_to_config(**register_dict)
def setup(self, components, collection=None):
self.blocks.setup_loader(component_manager=components, collection=collection)
self._components_manager = components
@property
def mellon_config(self):
return self._convert_to_mellon_config()
def _convert_to_mellon_config(self):
node = {}
node["label"] = self.config.label
node["category"] = self.config.category
node_param = {}
for inp_name, inp_param in self.config.input_params.items():
if inp_name in self.name_mapping:
mellon_name = self.name_mapping[inp_name]
else:
mellon_name = inp_name
if isinstance(inp_param, str):
param = {
"label": inp_param,
"type": inp_param,
"display": "input",
}
else:
param = inp_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}")
for comp_name, comp_param in self.config.component_params.items():
if comp_name in self.name_mapping:
mellon_name = self.name_mapping[comp_name]
else:
mellon_name = comp_name
if isinstance(comp_param, str):
param = {
"label": comp_param,
"type": comp_param,
"display": "input",
}
else:
param = comp_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}")
for out_name, out_param in self.config.output_params.items():
if out_name in self.name_mapping:
mellon_name = self.name_mapping[out_name]
else:
mellon_name = out_name
if isinstance(out_param, str):
param = {
"label": out_param,
"type": out_param,
"display": "output",
}
else:
param = out_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}")
node["params"] = node_param
return node
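For orientation, the returned dict for a hypothetical text-encoder block would look roughly like the following (the prompt entry comes from DEFAULT_PARAM_MAPS, the component and output entries from the string/group branches above; label and category take their defaults):

    {
        "label": "MyTextEncoderBlocks",   # defaults to the blocks class name
        "category": DEFAULT_CATEGORY,
        "params": {
            "prompt": {"label": "Prompt", "type": "string",
                       "default": "a bear sitting in a chair drinking a milkshake", "display": "textarea"},
            "text_encoders": {"label": "text_encoders", "type": "text_encoders", "display": "input"},
            "prompt_embeddings": {"label": "prompt_embeddings", "type": "prompt_embeddings", "display": "output"},
        },
    }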
def save_mellon_config(self, file_path):
"""
Save the Mellon configuration to a JSON file.
Args:
file_path (str or Path): Path where the JSON file will be saved
Returns:
Path: Path to the saved config file
"""
file_path = Path(file_path)
# Create directory if it doesn't exist
os.makedirs(file_path.parent, exist_ok=True)
# Create a combined dictionary with module definition and name mapping
config = {
"module": self.mellon_config,
"name_mapping": self.name_mapping
}
# Save the config to file
with open(file_path, 'w', encoding='utf-8') as f:
json.dump(config, f, indent=2)
logger.info(f"Mellon config and name mapping saved to {file_path}")
return file_path
@classmethod
def load_mellon_config(cls, file_path):
"""
Load a Mellon configuration from a JSON file.
Args:
file_path (str or Path): Path to the JSON file containing Mellon config
Returns:
dict: The loaded combined configuration containing 'module' and 'name_mapping'
"""
file_path = Path(file_path)
if not file_path.exists():
raise FileNotFoundError(f"Config file not found: {file_path}")
with open(file_path, 'r', encoding='utf-8') as f:
config = json.load(f)
logger.info(f"Mellon config loaded from {file_path}")
return config
def process_inputs(self, **kwargs):
params_components = {}
for comp_name, comp_param in self.config.component_params.items():
logger.debug(f"component: {comp_name}")
mellon_comp_name = self.name_mapping.get(comp_name, comp_name)
if mellon_comp_name in kwargs:
if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]:
comp = kwargs[mellon_comp_name].pop(comp_name)
else:
comp = kwargs.pop(mellon_comp_name)
if comp:
params_components[comp_name] = self._components_manager.get_one(comp["model_id"])
params_run = {}
for inp_name, inp_param in self.config.input_params.items():
logger.debug(f"input: {inp_name}")
mellon_inp_name = self.name_mapping.get(inp_name, inp_name)
if mellon_inp_name in kwargs:
if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]:
inp = kwargs[mellon_inp_name].pop(inp_name)
else:
inp = kwargs.pop(mellon_inp_name)
if inp is not None:
params_run[inp_name] = inp
return_output_names = list(self.config.output_params.keys())
return params_components, params_run, return_output_names
def execute(self, **kwargs):
params_components, params_run, return_output_names = self.process_inputs(**kwargs)
self.blocks.loader.update(**params_components)
output = self.blocks.run(**params_run, output=return_output_names)
return output
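A rough end-to-end usage sketch based on the methods in this diff (the blocks class, the components manager object, and the model_id are placeholders; the Mellon-side wiring itself is not part of this commit):

    blocks = MyModularBlocks()   # any ModularPipelineMixin / SequentialPipelineBlocks instance
    node = MellonNode(
        blocks,
        label="Text Encoder",
        prompt={"name": "text_input", "label": "Prompt", "type": "string",
                "default": "a bear sitting in a chair drinking a milkshake", "display": "textarea"},
    )

    node.setup(components_manager)                       # calls blocks.setup_loader(component_manager=..., collection=None)
    node.save_mellon_config("node_config/mellon.json")   # saves {"module": node.mellon_config, "name_mapping": ...}

    # Mellon-style kwargs are mapped back to block inputs/components via name_mapping,
    # then the blocks are run with this node's outputs requested.
    result = node.execute(
        text_input="a bear sitting in a chair drinking a milkshake",
        text_encoders={"text_encoder": {"model_id": "..."}},
    )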