1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

refactor modular loader: 1. load() only loads from_pretrained components when no specific names are given 2. update() accepts a ComponentSpec to create from 3. move the logic that updates _component_specs out of register_components and into each method that creates/updates the component: __init__/update/load

This commit is contained in:
yiyixuxu
2025-05-18 18:58:26 +02:00
parent d0fbf745e6
commit 163341d3dd

View File

@@ -1651,54 +1651,68 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
def register_components(self, **kwargs):
"""
Register components with their corresponding specs.
This method is called when component changed or __init__ is called.
Register components with their corresponding specifications.
This method is responsible for:
1. Sets component objects as attributes on the loader (e.g., self.unet = unet)
2. Updates the modular_model_index.json configuration for serialization
3. Adds components to the component manager if one is attached
This method is called when:
- Components are first initialized in __init__:
- from_pretrained components not loaded during __init__ so they are registered as None;
- non from_pretrained components are created during __init__ and registered as the object itself
- Components are updated with the `update()` method: e.g. loader.update(unet=unet) or loader.update(guider=guider_spec)
- (from_pretrained) Components are loaded with the `load()` method: e.g. loader.load(component_names=["unet"])
Args:
**kwargs: Keyword arguments where keys are component names and values are component objects.
E.g., register_components(unet=unet_model, text_encoder=encoder_model)
Notes:
- Components must be created from ComponentSpec (have _diffusers_load_id attribute)
- When registering None for a component, it updates the modular_model_index.json config but sets attribute to None
"""
for name, module in kwargs.items():
# current component spec
component_spec = self._component_specs.get(name)
if component_spec is None:
logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'")
continue
# check if it is the first time registration, i.e. calling from __init__
is_registered = hasattr(self, name)
# make sure the component is created from ComponentSpec
if module is not None and not hasattr(module, "_diffusers_load_id"):
raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.")
# actual library and class name of the module
if module is not None:
library, class_name = _fetch_class_library_tuple(module)
new_component_spec = ComponentSpec.from_component(name, module)
component_spec_dict = self._component_spec_to_dict(new_component_spec)
# actual library and class name of the module
library, class_name = _fetch_class_library_tuple(module) # e.g. ("diffusers", "UNet2DConditionModel")
# extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config
# e.g. {"repo": "stabilityai/stable-diffusion-2-1",
# "type_hint": ("diffusers", "UNet2DConditionModel"),
# "subfolder": "unet",
# "variant": None,
# "revision": None}
component_spec_dict = self._component_spec_to_dict(component_spec)
else:
# if module is None, e.g. self.register_components(unet=None) during __init__
# we do not update the spec,
# but we still need to update the modular_model_index.json config based on the component spec
library, class_name = None, None
# if module is None, we do not update the spec,
# but we still need to update the config to make sure it's synced with the component spec
# (in the case of first-time registration, we initialize the object with the component spec, and then call register_components() to register it to the config)
new_component_spec = component_spec
component_spec_dict = self._component_spec_to_dict(component_spec)
# do not register if component is not to be loaded from pretrained
if new_component_spec.default_creation_method == "from_pretrained":
register_dict = {name: (library, class_name, component_spec_dict)}
else:
register_dict = {}
register_dict = {name: (library, class_name, component_spec_dict)}
# set the component as attribute
# if it is not set yet, just set it and skip the process to check and warn below
if not is_registered:
self.register_to_config(**register_dict)
self._component_specs[name] = new_component_spec
setattr(self, name, module)
if module is not None and self._component_manager is not None:
if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None:
self._component_manager.add(name, module, self._collection)
continue
@@ -1707,10 +1721,6 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
if current_module is module:
logger.info(f"ModularLoader.register_components: {name} is already registered with same object, skipping")
continue
# if module is not an instance of the expected type, still register it but with a warning
if module is not None and component_spec.type_hint is not None and not isinstance(module, component_spec.type_hint):
logger.warning(f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}")
# warn if unregister
if current_module is not None and module is None:
@@ -1718,7 +1728,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
f"ModularLoader.register_components: setting '{name}' to None "
f"(was {current_module.__class__.__name__})"
)
# same type, new instance → debug
# same type, new instance → replace but send debug log
elif current_module is not None \
and module is not None \
and isinstance(module, current_module.__class__) \
@@ -1728,13 +1738,12 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
f"(same type {type(current_module).__name__}, new instance)"
)
# save modular_model_index.json config
# update modular_model_index.json config
self.register_to_config(**register_dict)
# update component spec
self._component_specs[name] = new_component_spec
# finally set models
setattr(self, name, module)
if module is not None and self._component_manager is not None:
# add to component manager if one is attached
if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None:
self._component_manager.add(name, module, self._collection)
@@ -1758,6 +1767,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
config_dict = self.load_config(modular_repo, **kwargs)
for name, value in config_dict.items():
# only update component_spec for from_pretrained components
if name in self._component_specs and self._component_specs[name].default_creation_method == "from_pretrained" and isinstance(value, (tuple, list)) and len(value) == 3:
library, class_name, component_spec_dict = value
component_spec = self._dict_to_component_spec(name, component_spec_dict)
@@ -1768,7 +1778,11 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
register_components_dict = {}
for name, component_spec in self._component_specs.items():
register_components_dict[name] = None
if component_spec.default_creation_method == "from_config":
component = component_spec.create()
else:
component = None
register_components_dict[name] = component
self.register_components(**register_components_dict)
default_configs = {}
@@ -1870,6 +1884,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
**kwargs: Component objects or configuration values to update:
- Component objects: Must be created using ComponentSpec (e.g., `unet=new_unet, text_encoder=new_encoder`)
- Configuration values: Simple values to update configuration settings (e.g., `requires_safety_checker=False`)
- ComponentSpec objects: only from_config specs are supported; the component is created by calling the spec's create() method
Raises:
ValueError: If a component wasn't created using ComponentSpec (doesn't have `_diffusers_load_id` attribute)
@@ -1893,22 +1908,52 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
unet=new_unet_model,
requires_safety_checker=False
)
# update with ComponentSpec objects
loader.update(
guider=ComponentSpec(name="guider", type_hint=ClassifierFreeGuidance, config={"guidance_scale": 5.0}, default_creation_method="from_config")
)
```
"""
# extract component_specs_updates & config_specs_updates from `specs`
passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs}
passed_component_specs = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs and isinstance(kwargs[k], ComponentSpec)}
passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs and not isinstance(kwargs[k], ComponentSpec)}
passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs}
for name, component in passed_components.items():
if not hasattr(component, "_diffusers_load_id"):
raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.")
raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.")
# YiYi TODO: remove this if we remove support for non config mixin components in `create()` method
if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin):
raise ValueError(
f"The passed component '{name}' is not supported in update() method "
f"because it is not supported in `ComponentSpec.from_component()`. "
f"Please pass a ComponentSpec object instead."
)
current_component_spec = self._component_specs[name]
# warn if type changed
if current_component_spec.type_hint is not None and not isinstance(component, current_component_spec.type_hint):
logger.warning(f"ModularLoader.update: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}")
# update _component_specs based on the new component
new_component_spec = ComponentSpec.from_component(name, component)
self._component_specs[name] = new_component_spec
if len(kwargs) > 0:
logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}")
self.register_components(**passed_components)
created_components = {}
for name, component_spec in passed_component_specs.items():
if component_spec.default_creation_method == "from_pretrained":
raise ValueError(f"ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update() method")
created_components[name] = component_spec.create()
current_component_spec = self._component_specs[name]
# warn if type changed
if current_component_spec.type_hint is not None and not isinstance(created_components[name], current_component_spec.type_hint):
logger.warning(f"ModularLoader.update: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}")
# update _component_specs based on the user passed component_spec
self._component_specs[name] = component_spec
self.register_components(**passed_components, **created_components)
config_to_register = {}
@@ -1932,8 +1977,9 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
- a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32}
- potentially overrides the ComponentSpec if a different loading field is passed in kwargs, e.g. `repo`, `variant`, `revision`, etc.
"""
# if not specific name, load all the components with default_creation_method == "from_pretrained"
if component_names is None:
component_names = list(self._component_specs.keys())
component_names = [name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_pretrained"]
elif not isinstance(component_names, list):
component_names = [component_names]
@@ -1958,7 +2004,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
# check if the default is specified
component_load_kwargs[key] = value["default"]
try:
components_to_register[name] = spec.create(**component_load_kwargs)
components_to_register[name] = spec.load(**component_load_kwargs)
except Exception as e:
logger.warning(f"Failed to create component '{name}': {e}")
@@ -1986,7 +2032,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs):
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs):
config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs)
expected_component = set(config_dict.pop("_components_names"))
@@ -2010,7 +2056,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
else:
# append an empty component spec for those not in modular_model_index
component_specs.append(ComponentSpec(name=name, default_creation_method="from_config"))
return cls(component_specs + config_specs)
return cls(component_specs + config_specs, component_manager=component_manager, collection=collection)
@staticmethod