mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
modularpipeloine -> modularpipelineloader, setup_loader, make loader configmixin etc
This commit is contained in:
@@ -30,7 +30,11 @@ from ..utils import (
|
||||
logging,
|
||||
)
|
||||
from .pipeline_loading_utils import _get_pipeline_class
|
||||
from .modular_pipeline_util import (
|
||||
from .modular_pipeline_utils import (
|
||||
ComponentSpec,
|
||||
ConfigSpec,
|
||||
InputParam,
|
||||
OutputParam,
|
||||
format_components,
|
||||
format_configs,
|
||||
format_input_params,
|
||||
@@ -41,16 +45,16 @@ from .modular_pipeline_util import (
|
||||
make_doc_string,
|
||||
)
|
||||
|
||||
|
||||
from copy import deepcopy
|
||||
if is_accelerate_available():
|
||||
import accelerate
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
MODULAR_PIPELINE_MAPPING = OrderedDict(
|
||||
MODULAR_LOADER_MAPPING = OrderedDict(
|
||||
[
|
||||
("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
|
||||
("stable-diffusion-xl", "StableDiffusionXLModularLoader"),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -148,45 +152,6 @@ class BlockState:
|
||||
return f"BlockState(\n{attributes}\n)"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComponentSpec:
|
||||
"""Specification for a pipeline component."""
|
||||
name: str
|
||||
type_hint: Type
|
||||
description: Optional[str] = None
|
||||
obj: Any = None # you can create a default component if it is a stateless class like scheduler, guider or image processor
|
||||
default_class_name: Union[str, List[str], Tuple[str, str]] = None # Either "class_name" or ["module", "class_name"]
|
||||
default_repo: Optional[Union[str, List[str]]] = None # either "repo" or ["repo", "subfolder"]
|
||||
|
||||
@dataclass
|
||||
class ConfigSpec:
|
||||
"""Specification for a pipeline configuration parameter."""
|
||||
name: str
|
||||
default: Any
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputParam:
|
||||
name: str
|
||||
type_hint: Any = None
|
||||
default: Any = None
|
||||
required: bool = False
|
||||
description: str = ""
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
|
||||
|
||||
@dataclass
|
||||
class OutputParam:
|
||||
name: str
|
||||
type_hint: Any = None
|
||||
description: str = ""
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
|
||||
|
||||
|
||||
class PipelineBlock:
|
||||
|
||||
model_name = None
|
||||
@@ -1027,21 +992,109 @@ class ModularPipelineMixin:
|
||||
"""
|
||||
Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.components_manager = None
|
||||
self.components_manager_prefix = ""
|
||||
self.components_state = None
|
||||
|
||||
# YiYi TODO: not sure this is the best method name
|
||||
def compile(self, components_manager: ComponentsManager, label: Optional[str] = None):
|
||||
self.components_manager = components_manager
|
||||
self.components_manager_prefix = "" if label is None else f"{label}_"
|
||||
self.components_state = ComponentsState(self.expected_components, self.expected_configs)
|
||||
def register_loader(self, global_components_manager: ComponentsManager, label: Optional[str] = None):
|
||||
self._global_components_manager = global_components_manager
|
||||
self._label = label
|
||||
|
||||
#YiYi TODO: add validation for kwargs?
|
||||
def setup_loader(self, **kwargs):
|
||||
"""
|
||||
Set up the components loader with repository information.
|
||||
|
||||
components_to_add = self.components_manager.get(f"{self.components_manager_prefix}*")
|
||||
self.components_state.update_states(self.expected_components, self.expected_configs, **components_to_add)
|
||||
Args:
|
||||
**kwargs: Configuration for component loading.
|
||||
- repo: Default repository to use for all components
|
||||
- For individual components, pass a tuple of (repo, subfolder)
|
||||
e.g., text_encoder=("repo_name", "text_encoder")
|
||||
|
||||
Examples:
|
||||
# Set repo for all components (subfolder will be component name)
|
||||
setup_loader(repo="stabilityai/stable-diffusion-xl-base-1.0")
|
||||
|
||||
# Set specific repo/subfolder for individual components
|
||||
setup_loader(
|
||||
unet=("stabilityai/stable-diffusion-xl-base-1.0", "unet"),
|
||||
text_encoder=("stabilityai/stable-diffusion-xl-base-1.0", "text_encoder")
|
||||
)
|
||||
|
||||
# Set default repo and override for specific components
|
||||
setup_loader(
|
||||
repo="stabilityai/stable-diffusion-xl-base-1.0",
|
||||
unet=(""stabilityai/stable-diffusion-xl-refiner-1.0", "unet")
|
||||
)
|
||||
"""
|
||||
|
||||
# Create deep copies to avoid modifying the original specs
|
||||
components_specs = deepcopy(self.expected_components)
|
||||
config_specs = deepcopy(self.expected_configs)
|
||||
|
||||
expected_component_names = set([c.name for c in components_specs])
|
||||
expected_config_names = set([c.name for c in config_specs])
|
||||
|
||||
# Check if a default repo is provided
|
||||
repo = kwargs.pop("repo", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
variant = kwargs.pop("variant", None)
|
||||
|
||||
passed_component_kwargs = {k: kwargs.pop(k) for k in expected_component_names if k in kwargs}
|
||||
passed_config_kwargs = {k: kwargs.pop(k) for k in expected_config_names if k in kwargs}
|
||||
if len(kwargs) > 0:
|
||||
logger.warning(f"Unused keyword arguments: {kwargs.keys()}. This input will be ignored.")
|
||||
|
||||
for name, value in passed_component_kwargs.items():
|
||||
if not isinstance(value, (tuple, list, str)):
|
||||
raise ValueError(f"Invalid value for component '{name}': {value}. Expected a string, tuple or list")
|
||||
elif isinstance(value, (tuple, list)) and len(value) > 2:
|
||||
raise ValueError(f"Invalid value for component '{name}': {value}. Expected a tuple or list of length 1 or 2.")
|
||||
|
||||
for name, value in passed_config_kwargs.items():
|
||||
if not isinstance(value, str):
|
||||
raise ValueError(f"Invalid value for config '{name}': {value}. Expected a string")
|
||||
|
||||
# First apply default repo to all components if provided
|
||||
if repo is not None:
|
||||
for component_spec in components_specs:
|
||||
# components defined with a config are classes like image_processor or guider,
|
||||
# skip setting loading related attributes for them, they should be initialized with the default config
|
||||
if component_spec.config is None:
|
||||
component_spec.repo = repo
|
||||
|
||||
# YiYi TODO: should also accept `revision` and `variant` as a dict here so user can set different values for different components
|
||||
if revision is not None:
|
||||
component_spec.revision = revision
|
||||
if variant is not None:
|
||||
component_spec.variant = variant
|
||||
for config_spec in config_specs:
|
||||
config_spec.repo = repo
|
||||
|
||||
# apply component-specific overrides
|
||||
for name, value in passed_component_kwargs.items():
|
||||
if not isinstance(value, (tuple, list)):
|
||||
value = (value,)
|
||||
# Find the matching component spec
|
||||
for component_spec in components_specs:
|
||||
if component_spec.name == name:
|
||||
# Handle tuple of (repo, subfolder)
|
||||
component_spec.repo = value[0]
|
||||
if len(value) > 1:
|
||||
component_spec.subfolder = value[1]
|
||||
break
|
||||
|
||||
# apply config overrides
|
||||
for name, value in passed_config_kwargs.items():
|
||||
for config_spec in config_specs:
|
||||
if config_spec.name == name:
|
||||
config_spec.repo = value
|
||||
break
|
||||
|
||||
# Import components loader (it is model-specific class)
|
||||
loader_class_name = MODULAR_LOADER_MAPPING[self.model_name]
|
||||
diffusers_module = importlib.import_module(self.__module__.split(".")[0])
|
||||
loader_class = getattr(diffusers_module, loader_class_name)
|
||||
|
||||
# Create the loader with the updated specs
|
||||
self.loader = loader_class(components_specs, config_specs)
|
||||
|
||||
|
||||
@property
|
||||
@@ -1105,24 +1158,69 @@ class ModularPipelineMixin:
|
||||
raise ValueError(f"Output '{output}' is not a valid output type")
|
||||
|
||||
|
||||
class ComponentsState(ConfigMixin):
|
||||
# YiYi NOTE: not sure if this needs to be a ConfigMixin
|
||||
class ModularPipelineLoader(ConfigMixin):
|
||||
"""
|
||||
Base class for all Modular pipelines.
|
||||
Base class for all Modular pipelines loaders.
|
||||
|
||||
"""
|
||||
config_name = "model_index.json"
|
||||
config_name = "modular_model_index.json"
|
||||
|
||||
|
||||
def register_components(self, **kwargs):
|
||||
for name, module in kwargs.items():
|
||||
|
||||
repo = self.components_specs[name].repo
|
||||
subfolder = self.components_specs[name].subfolder
|
||||
# retrieve library
|
||||
if module is None or isinstance(module, (tuple, list)) and module[0] is None:
|
||||
register_dict = {name: (None, None, (None, None))}
|
||||
else:
|
||||
library, class_name = _fetch_class_library_tuple(module)
|
||||
register_dict = {name: (library, class_name, (repo, subfolder))}
|
||||
|
||||
# save model index config
|
||||
self.register_to_config(**register_dict)
|
||||
|
||||
# set models
|
||||
setattr(self, name, module)
|
||||
|
||||
def __setattr__(self, name: str, value: Any):
|
||||
if name in self.__dict__ and hasattr(self.config, name):
|
||||
|
||||
repo = self.components_specs[name].repo
|
||||
subfolder = self.components_specs[name].subfolder
|
||||
|
||||
# We need to overwrite the config if name exists in config
|
||||
if isinstance(getattr(self.config, name), (tuple, list)):
|
||||
if value is not None and self.config[name][0] is not None:
|
||||
library, class_name = _fetch_class_library_tuple(value)
|
||||
register_dict = {name: (library, class_name, (repo, subfolder))}
|
||||
else:
|
||||
register_dict = {name: (None, None, (None, None))}
|
||||
|
||||
self.register_to_config(**register_dict)
|
||||
else:
|
||||
self.register_to_config(**{name: value})
|
||||
|
||||
super().__setattr__(name, value)
|
||||
|
||||
|
||||
def __init__(self, component_specs, config_specs):
|
||||
|
||||
self.components_specs = deepcopy(component_specs)
|
||||
self.configs_specs = deepcopy(config_specs)
|
||||
|
||||
for component_spec in component_specs:
|
||||
if component_spec.obj is not None:
|
||||
setattr(self, component_spec.name, component_spec.obj)
|
||||
if component_spec.config is not None:
|
||||
component_obj = component_spec.type_hint(**component_spec.config)
|
||||
self.register_components(component_spec.name, component_obj)
|
||||
else:
|
||||
setattr(self, component_spec.name, None)
|
||||
self.register_components(component_spec.name, None)
|
||||
|
||||
default_configs = {}
|
||||
for config_spec in config_specs:
|
||||
default_configs[config_spec.name] = config_spec.default
|
||||
default_configs[config_spec.name] = config_spec.value
|
||||
self.register_to_config(**default_configs)
|
||||
|
||||
|
||||
@@ -1187,7 +1285,7 @@ class ComponentsState(ConfigMixin):
|
||||
components[component_spec.name] = getattr(self, component_spec.name)
|
||||
return components
|
||||
|
||||
def update_states(self, expected_components, expected_configs, **kwargs):
|
||||
def update(self, **kwargs):
|
||||
"""
|
||||
Update components and configs after instance creation. Auxiliaries (e.g. image_processor) should be defined for
|
||||
each pipeline block, does not need to be updated by users. Logs if existing non-None components are being
|
||||
@@ -1197,7 +1295,7 @@ class ComponentsState(ConfigMixin):
|
||||
kwargs (dict): Keyword arguments to update the states.
|
||||
"""
|
||||
|
||||
for component in expected_components:
|
||||
for component in self.components_specs:
|
||||
if component.name in kwargs:
|
||||
if hasattr(self, component.name) and getattr(self, component.name) is not None:
|
||||
current_component = getattr(self, component.name)
|
||||
@@ -1217,10 +1315,10 @@ class ComponentsState(ConfigMixin):
|
||||
f"with new value (type: {type(new_component).__name__})"
|
||||
)
|
||||
|
||||
setattr(self.components_state, component.name, kwargs.pop(component.name))
|
||||
setattr(self, component.name, kwargs.pop(component.name))
|
||||
|
||||
configs_to_add = {}
|
||||
for config in expected_configs:
|
||||
for config in self.configs_specs:
|
||||
if config.name in kwargs:
|
||||
configs_to_add[config.name] = kwargs.pop(config.name)
|
||||
self.register_to_config(**configs_to_add)
|
||||
@@ -1228,3 +1326,4 @@ class ComponentsState(ConfigMixin):
|
||||
# YiYi TODO: should support to method
|
||||
def to(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ from ...utils.torch_utils import randn_tensor, unwrap_module
|
||||
from ..controlnet.multicontrolnet import MultiControlNetModel
|
||||
from ..modular_pipeline import (
|
||||
AutoPipelineBlocks,
|
||||
ModularPipeline,
|
||||
ModularPipelineLoader,
|
||||
PipelineBlock,
|
||||
PipelineState,
|
||||
InputParam,
|
||||
@@ -58,6 +58,7 @@ from transformers import (
|
||||
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...guiders import GuiderType, ClassifierFreeGuidance
|
||||
from ...configuration_utils import FrozenDict
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -646,7 +647,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKL),
|
||||
ComponentSpec("image_processor", VaeImageProcessor, obj=VaeImageProcessor()),
|
||||
ComponentSpec("image_processor", VaeImageProcessor, config=FrozenDict({"vae_scale_factor": 8})),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -741,8 +742,8 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKL),
|
||||
ComponentSpec("image_processor", VaeImageProcessor, obj=VaeImageProcessor()),
|
||||
ComponentSpec("mask_processor", VaeImageProcessor, obj=VaeImageProcessor(do_normalize=False, do_binarize=True, do_convert_grayscale=True)),
|
||||
ComponentSpec("image_processor", VaeImageProcessor, config=FrozenDict({"vae_scale_factor": 8})),
|
||||
ComponentSpec("mask_processor", VaeImageProcessor, config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True})),
|
||||
]
|
||||
|
||||
|
||||
@@ -1728,7 +1729,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
|
||||
@property
|
||||
def expected_configs(self) -> List[ConfigSpec]:
|
||||
return [ConfigSpec("requires_aesthetics_score", default=False),]
|
||||
return [ConfigSpec("requires_aesthetics_score", False),]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
@@ -2063,7 +2064,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()),
|
||||
ComponentSpec("guider", GuiderType, config=FrozenDict({"guidance_scale": 7.5})),
|
||||
ComponentSpec("scheduler", KarrasDiffusionSchedulers),
|
||||
ComponentSpec("unet", UNet2DConditionModel),
|
||||
]
|
||||
@@ -2332,11 +2333,11 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()),
|
||||
ComponentSpec("guider", GuiderType, config=FrozenDict({"guidance_scale": 7.5})),
|
||||
ComponentSpec("scheduler", KarrasDiffusionSchedulers),
|
||||
ComponentSpec("unet", UNet2DConditionModel),
|
||||
ComponentSpec("controlnet", ControlNetModel),
|
||||
ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)),
|
||||
ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False})),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -2763,8 +2764,8 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
|
||||
ComponentSpec("unet", UNet2DConditionModel),
|
||||
ComponentSpec("controlnet", ControlNetUnionModel),
|
||||
ComponentSpec("scheduler", KarrasDiffusionSchedulers),
|
||||
ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()),
|
||||
ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)),
|
||||
ComponentSpec("guider", GuiderType, config=FrozenDict({"guidance_scale": 7.5})),
|
||||
ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False})),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -3179,7 +3180,7 @@ class StableDiffusionXLDecodeLatentsStep(PipelineBlock):
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKL),
|
||||
ComponentSpec("image_processor", VaeImageProcessor, obj=VaeImageProcessor())
|
||||
ComponentSpec("image_processor", VaeImageProcessor, config=FrozenDict({"vae_scale_factor": 8}))
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -3570,9 +3571,14 @@ SDXL_SUPPORTED_BLOCKS = {
|
||||
}
|
||||
|
||||
|
||||
# YiYi TODO: rename to components etc. and not inherit from ModularPipeline
|
||||
class StableDiffusionXLComponentStates(
|
||||
ModularPipeline,
|
||||
# YiYi Notes: model specific components:
|
||||
## (1) it should inherit from ModularPipelineComponents
|
||||
## (2) acts like a container that holds components and configs
|
||||
## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents
|
||||
## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin)
|
||||
## (5) how to use together with Components_manager?
|
||||
class StableDiffusionXLModularLoader(
|
||||
ModularPipelineLoader,
|
||||
StableDiffusionMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
|
||||
Reference in New Issue
Block a user