1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

add componentspec and configspec

This commit is contained in:
yiyixuxu
2025-02-27 19:18:10 +01:00
parent 96795afc72
commit ee842839ef
3 changed files with 1215 additions and 39 deletions

View File

@@ -743,3 +743,6 @@ class APGGuider:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
return noise_pred
Guiders = Union[CFGGuider, PAGGuider, APGGuider]

View File

@@ -16,7 +16,7 @@ import traceback
import warnings
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Tuple, Union, Optional, Type
import torch
@@ -338,11 +338,28 @@ def make_doc_string(inputs, intermediates_inputs, outputs, description=""):
return output
@dataclass
class ComponentSpec:
"""Specification for a pipeline component."""
name: str
type_hint: Optional[Type] = None
description: Optional[str] = None
default: Any = None # you can create a default component if it is a stateless class like scheduler, guider or image processor
default_class_name: Union[str, List[str], Tuple[str, str]] # Either "class_name" or ["module", "class_name"]
default_repo: Optional[Union[str, List[str]]] = None # either "repo" or ["repo", "subfolder"]
@dataclass
class ConfigSpec:
"""Specification for a pipeline configuration parameter."""
name: str
default: Any
description: Optional[str] = None
type_hint: Optional[Type] = None
class PipelineBlock:
# YiYi Notes: do we need this?
# pipelie block should set the default value for all expected config/components, so maybe we do not need to explicitly set the list
expected_components = []
expected_configs = []
component_specs: List[ComponentSpec] = []
config_specs: List[ConfigSpec] = []
model_name = None
@property
@@ -409,14 +426,45 @@ class PipelineBlock:
desc = '\n'.join(desc) + '\n'
# Components section
expected_components = set(getattr(self, "expected_components", []))
expected_components = getattr(self, "expected_components", [])
expected_component_names = {comp.name for comp in expected_components} if expected_components else set()
loaded_components = set(self.components.keys())
all_components = sorted(expected_components | loaded_components)
all_components = sorted(expected_component_names | loaded_components)
main_components = []
auxiliary_components = []
for k in all_components:
component_str = f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}"
# Get component spec if available
component_spec = next((comp for comp in expected_components if comp.name == k), None)
if k in loaded_components:
component_type = type(self.components[k]).__name__
component_str = f" - {k}={component_type}"
# Add expected type info if available
if component_spec and component_spec.class_name:
expected_type = component_spec.class_name
if isinstance(expected_type, (list, tuple)):
expected_type = expected_type[1] # Get class name from [module, class_name]
if expected_type != component_type:
component_str += f" (expected: {expected_type})"
else:
# Component not loaded but expected
if component_spec:
expected_type = component_spec.class_name
if isinstance(expected_type, (list, tuple)):
expected_type = expected_type[1] # Get class name from [module, class_name]
component_str = f" - {k} (expected: {expected_type})"
# Add repo info if available
if component_spec.default_repo:
repo_info = component_spec.default_repo
if component_spec.subfolder:
repo_info += f", subfolder={component_spec.subfolder}"
component_str += f" [{repo_info}]"
else:
component_str = f" - {k}"
if k in getattr(self, "auxiliary_components", []):
auxiliary_components.append(component_str)
else:
@@ -793,18 +841,52 @@ class AutoPipelineBlocks:
desc = '\n'.join(desc) + '\n'
# Components section
expected_components = set(getattr(self, "expected_components", []))
expected_components = getattr(self, "expected_components", [])
expected_component_names = {comp.name for comp in expected_components} if expected_components else set()
loaded_components = set(self.components.keys())
all_components = sorted(expected_components | loaded_components)
components_str = " Components:\n" + "\n".join(
f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}"
for k in all_components
)
all_components = sorted(expected_component_names | loaded_components)
# Auxiliaries section
auxiliaries_str = " Auxiliaries:\n" + "\n".join(
f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items()
)
main_components = []
for k in all_components:
# Get component spec if available
component_spec = next((comp for comp in expected_components if comp.name == k), None)
if k in loaded_components:
component_type = type(self.components[k]).__name__
component_str = f" - {k}={component_type}"
# Add expected type info if available
if component_spec and component_spec.class_name:
expected_type = component_spec.class_name
if isinstance(expected_type, (list, tuple)):
expected_type = expected_type[1] # Get class name from [module, class_name]
if expected_type != component_type:
component_str += f" (expected: {expected_type})"
else:
# Component not loaded but expected
if component_spec:
expected_type = component_spec.class_name
if isinstance(expected_type, (list, tuple)):
expected_type = expected_type[1] # Get class name from [module, class_name]
component_str = f" - {k} (expected: {expected_type})"
# Add repo info if available
if component_spec.default_repo:
repo_info = component_spec.default_repo
if component_spec.subfolder:
repo_info += f", subfolder={component_spec.subfolder}"
component_str += f" [{repo_info}]"
else:
component_str = f" - {k}"
main_components.append(component_str)
components = "Components:\n" + "\n".join(main_components)
# Configs section
expected_configs = set(getattr(self, "expected_configs", []))
@@ -1188,19 +1270,54 @@ class SequentialPipelineBlocks:
desc = '\n'.join(desc) + '\n'
# Components section
expected_components = set(getattr(self, "expected_components", []))
expected_components = getattr(self, "expected_components", [])
expected_component_names = {comp.name for comp in expected_components} if expected_components else set()
loaded_components = set(self.components.keys())
all_components = sorted(expected_components | loaded_components)
components_str = " Components:\n" + "\n".join(
f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}"
for k in all_components
)
all_components = sorted(expected_component_names | loaded_components)
# Auxiliaries section
auxiliaries_str = " Auxiliaries:\n" + "\n".join(
f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items()
)
main_components = []
for k in all_components:
# Get component spec if available
component_spec = next((comp for comp in expected_components if comp.name == k), None)
if k in loaded_components:
component_type = type(self.components[k]).__name__
component_str = f" - {k}={component_type}"
# Add expected type info if available
if component_spec and component_spec.class_name:
expected_type = component_spec.class_name
if isinstance(expected_type, (list, tuple)):
expected_type = expected_type[1] # Get class name from [module, class_name]
if expected_type != component_type:
component_str += f" (expected: {expected_type})"
else:
# Component not loaded but expected
if component_spec:
expected_type = component_spec.class_name
if isinstance(expected_type, (list, tuple)):
expected_type = expected_type[1] # Get class name from [module, class_name]
component_str = f" - {k} (expected: {expected_type})"
# Add repo info if available
if component_spec.default_repo:
repo_info = component_spec.default_repo
if component_spec.subfolder:
repo_info += f", subfolder={component_spec.subfolder}"
component_str += f" [{repo_info}]"
else:
component_str = f" - {k}"
main_components.append(component_str)
components = "Components:\n" + "\n".join(main_components)
# Configs section
expected_configs = set(getattr(self, "expected_configs", []))
loaded_configs = set(self.configs.keys())
@@ -1558,7 +1675,7 @@ class ModularPipeline(ConfigMixin):
return output
# YiYi TO-DO: try to unify the to method with the one in DiffusionPipeline
# YiYi TODO: try to unify the to method with the one in DiffusionPipeline
# Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to
def to(self, *args, **kwargs):
r"""