mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
add componentspec and configspec
This commit is contained in:
@@ -743,3 +743,6 @@ class APGGuider:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
|
||||
return noise_pred
|
||||
|
||||
|
||||
Guiders = Union[CFGGuider, PAGGuider, APGGuider]
|
||||
@@ -16,7 +16,7 @@ import traceback
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from typing import Any, Dict, List, Tuple, Union, Optional, Type
|
||||
|
||||
|
||||
import torch
|
||||
@@ -338,11 +338,28 @@ def make_doc_string(inputs, intermediates_inputs, outputs, description=""):
|
||||
return output
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComponentSpec:
|
||||
"""Specification for a pipeline component."""
|
||||
name: str
|
||||
type_hint: Optional[Type] = None
|
||||
description: Optional[str] = None
|
||||
default: Any = None # you can create a default component if it is a stateless class like scheduler, guider or image processor
|
||||
default_class_name: Union[str, List[str], Tuple[str, str]] # Either "class_name" or ["module", "class_name"]
|
||||
default_repo: Optional[Union[str, List[str]]] = None # either "repo" or ["repo", "subfolder"]
|
||||
|
||||
@dataclass
|
||||
class ConfigSpec:
|
||||
"""Specification for a pipeline configuration parameter."""
|
||||
name: str
|
||||
default: Any
|
||||
description: Optional[str] = None
|
||||
type_hint: Optional[Type] = None
|
||||
|
||||
class PipelineBlock:
|
||||
# YiYi Notes: do we need this?
|
||||
# pipelie block should set the default value for all expected config/components, so maybe we do not need to explicitly set the list
|
||||
expected_components = []
|
||||
expected_configs = []
|
||||
|
||||
component_specs: List[ComponentSpec] = []
|
||||
config_specs: List[ConfigSpec] = []
|
||||
model_name = None
|
||||
|
||||
@property
|
||||
@@ -409,14 +426,45 @@ class PipelineBlock:
|
||||
desc = '\n'.join(desc) + '\n'
|
||||
|
||||
# Components section
|
||||
expected_components = set(getattr(self, "expected_components", []))
|
||||
expected_components = getattr(self, "expected_components", [])
|
||||
expected_component_names = {comp.name for comp in expected_components} if expected_components else set()
|
||||
loaded_components = set(self.components.keys())
|
||||
all_components = sorted(expected_components | loaded_components)
|
||||
all_components = sorted(expected_component_names | loaded_components)
|
||||
|
||||
main_components = []
|
||||
auxiliary_components = []
|
||||
for k in all_components:
|
||||
component_str = f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}"
|
||||
# Get component spec if available
|
||||
component_spec = next((comp for comp in expected_components if comp.name == k), None)
|
||||
|
||||
if k in loaded_components:
|
||||
component_type = type(self.components[k]).__name__
|
||||
component_str = f" - {k}={component_type}"
|
||||
|
||||
# Add expected type info if available
|
||||
if component_spec and component_spec.class_name:
|
||||
expected_type = component_spec.class_name
|
||||
if isinstance(expected_type, (list, tuple)):
|
||||
expected_type = expected_type[1] # Get class name from [module, class_name]
|
||||
if expected_type != component_type:
|
||||
component_str += f" (expected: {expected_type})"
|
||||
else:
|
||||
# Component not loaded but expected
|
||||
if component_spec:
|
||||
expected_type = component_spec.class_name
|
||||
if isinstance(expected_type, (list, tuple)):
|
||||
expected_type = expected_type[1] # Get class name from [module, class_name]
|
||||
component_str = f" - {k} (expected: {expected_type})"
|
||||
|
||||
# Add repo info if available
|
||||
if component_spec.default_repo:
|
||||
repo_info = component_spec.default_repo
|
||||
if component_spec.subfolder:
|
||||
repo_info += f", subfolder={component_spec.subfolder}"
|
||||
component_str += f" [{repo_info}]"
|
||||
else:
|
||||
component_str = f" - {k}"
|
||||
|
||||
if k in getattr(self, "auxiliary_components", []):
|
||||
auxiliary_components.append(component_str)
|
||||
else:
|
||||
@@ -793,18 +841,52 @@ class AutoPipelineBlocks:
|
||||
desc = '\n'.join(desc) + '\n'
|
||||
|
||||
# Components section
|
||||
expected_components = set(getattr(self, "expected_components", []))
|
||||
expected_components = getattr(self, "expected_components", [])
|
||||
expected_component_names = {comp.name for comp in expected_components} if expected_components else set()
|
||||
loaded_components = set(self.components.keys())
|
||||
all_components = sorted(expected_components | loaded_components)
|
||||
components_str = " Components:\n" + "\n".join(
|
||||
f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}"
|
||||
for k in all_components
|
||||
)
|
||||
all_components = sorted(expected_component_names | loaded_components)
|
||||
|
||||
# Auxiliaries section
|
||||
auxiliaries_str = " Auxiliaries:\n" + "\n".join(
|
||||
f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items()
|
||||
)
|
||||
main_components = []
|
||||
for k in all_components:
|
||||
# Get component spec if available
|
||||
component_spec = next((comp for comp in expected_components if comp.name == k), None)
|
||||
|
||||
if k in loaded_components:
|
||||
component_type = type(self.components[k]).__name__
|
||||
component_str = f" - {k}={component_type}"
|
||||
|
||||
# Add expected type info if available
|
||||
if component_spec and component_spec.class_name:
|
||||
expected_type = component_spec.class_name
|
||||
if isinstance(expected_type, (list, tuple)):
|
||||
expected_type = expected_type[1] # Get class name from [module, class_name]
|
||||
if expected_type != component_type:
|
||||
component_str += f" (expected: {expected_type})"
|
||||
else:
|
||||
# Component not loaded but expected
|
||||
if component_spec:
|
||||
expected_type = component_spec.class_name
|
||||
if isinstance(expected_type, (list, tuple)):
|
||||
expected_type = expected_type[1] # Get class name from [module, class_name]
|
||||
component_str = f" - {k} (expected: {expected_type})"
|
||||
|
||||
# Add repo info if available
|
||||
if component_spec.default_repo:
|
||||
repo_info = component_spec.default_repo
|
||||
if component_spec.subfolder:
|
||||
repo_info += f", subfolder={component_spec.subfolder}"
|
||||
component_str += f" [{repo_info}]"
|
||||
else:
|
||||
component_str = f" - {k}"
|
||||
|
||||
|
||||
main_components.append(component_str)
|
||||
|
||||
components = "Components:\n" + "\n".join(main_components)
|
||||
|
||||
# Configs section
|
||||
expected_configs = set(getattr(self, "expected_configs", []))
|
||||
@@ -1188,19 +1270,54 @@ class SequentialPipelineBlocks:
|
||||
desc = '\n'.join(desc) + '\n'
|
||||
|
||||
# Components section
|
||||
expected_components = set(getattr(self, "expected_components", []))
|
||||
expected_components = getattr(self, "expected_components", [])
|
||||
expected_component_names = {comp.name for comp in expected_components} if expected_components else set()
|
||||
loaded_components = set(self.components.keys())
|
||||
all_components = sorted(expected_components | loaded_components)
|
||||
components_str = " Components:\n" + "\n".join(
|
||||
f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}"
|
||||
for k in all_components
|
||||
)
|
||||
all_components = sorted(expected_component_names | loaded_components)
|
||||
|
||||
# Auxiliaries section
|
||||
auxiliaries_str = " Auxiliaries:\n" + "\n".join(
|
||||
f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items()
|
||||
)
|
||||
|
||||
main_components = []
|
||||
for k in all_components:
|
||||
# Get component spec if available
|
||||
component_spec = next((comp for comp in expected_components if comp.name == k), None)
|
||||
|
||||
if k in loaded_components:
|
||||
component_type = type(self.components[k]).__name__
|
||||
component_str = f" - {k}={component_type}"
|
||||
|
||||
# Add expected type info if available
|
||||
if component_spec and component_spec.class_name:
|
||||
expected_type = component_spec.class_name
|
||||
if isinstance(expected_type, (list, tuple)):
|
||||
expected_type = expected_type[1] # Get class name from [module, class_name]
|
||||
if expected_type != component_type:
|
||||
component_str += f" (expected: {expected_type})"
|
||||
else:
|
||||
# Component not loaded but expected
|
||||
if component_spec:
|
||||
expected_type = component_spec.class_name
|
||||
if isinstance(expected_type, (list, tuple)):
|
||||
expected_type = expected_type[1] # Get class name from [module, class_name]
|
||||
component_str = f" - {k} (expected: {expected_type})"
|
||||
|
||||
# Add repo info if available
|
||||
if component_spec.default_repo:
|
||||
repo_info = component_spec.default_repo
|
||||
if component_spec.subfolder:
|
||||
repo_info += f", subfolder={component_spec.subfolder}"
|
||||
component_str += f" [{repo_info}]"
|
||||
else:
|
||||
component_str = f" - {k}"
|
||||
|
||||
|
||||
main_components.append(component_str)
|
||||
|
||||
components = "Components:\n" + "\n".join(main_components)
|
||||
|
||||
# Configs section
|
||||
expected_configs = set(getattr(self, "expected_configs", []))
|
||||
loaded_configs = set(self.configs.keys())
|
||||
@@ -1558,7 +1675,7 @@ class ModularPipeline(ConfigMixin):
|
||||
|
||||
return output
|
||||
|
||||
# YiYi TO-DO: try to unify the to method with the one in DiffusionPipeline
|
||||
# YiYi TODO: try to unify the to method with the one in DiffusionPipeline
|
||||
# Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to
|
||||
def to(self, *args, **kwargs):
|
||||
r"""
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user