mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
refactor modular loader: 1. load only load (pretrained components only if not specific names) 2. update acceept create spec 3. move the updte _componeent_spec logic outside register_components to each method that create/update the component: __init__/update/load
This commit is contained in:
@@ -1651,54 +1651,68 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
|
||||
def register_components(self, **kwargs):
|
||||
"""
|
||||
Register components with their corresponding specs.
|
||||
This method is called when component changed or __init__ is called.
|
||||
|
||||
Register components with their corresponding specifications.
|
||||
|
||||
This method is responsible for:
|
||||
1. Sets component objects as attributes on the loader (e.g., self.unet = unet)
|
||||
2. Updates the modular_model_index.json configuration for serialization
|
||||
4. Adds components to the component manager if one is attached
|
||||
|
||||
This method is called when:
|
||||
- Components are first initialized in __init__:
|
||||
- from_pretrained components not loaded during __init__ so they are registered as None;
|
||||
- non from_pretrained components are created during __init__ and registered as the object itself
|
||||
- Components are updated with the `update()` method: e.g. loader.update(unet=unet) or loader.update(guider=guider_spec)
|
||||
- (from_pretrained) Components are loaded with the `load()` method: e.g. loader.load(component_names=["unet"])
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments where keys are component names and values are component objects.
|
||||
E.g., register_components(unet=unet_model, text_encoder=encoder_model)
|
||||
|
||||
Notes:
|
||||
- Components must be created from ComponentSpec (have _diffusers_load_id attribute)
|
||||
- When registering None for a component, it updates the modular_model_index.json config but sets attribute to None
|
||||
"""
|
||||
for name, module in kwargs.items():
|
||||
|
||||
# current component spec
|
||||
component_spec = self._component_specs.get(name)
|
||||
if component_spec is None:
|
||||
logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'")
|
||||
continue
|
||||
|
||||
# check if it is the first time registration, i.e. calling from __init__
|
||||
is_registered = hasattr(self, name)
|
||||
|
||||
# make sure the component is created from ComponentSpec
|
||||
if module is not None and not hasattr(module, "_diffusers_load_id"):
|
||||
raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.")
|
||||
|
||||
# actual library and class name of the module
|
||||
|
||||
if module is not None:
|
||||
library, class_name = _fetch_class_library_tuple(module)
|
||||
new_component_spec = ComponentSpec.from_component(name, module)
|
||||
component_spec_dict = self._component_spec_to_dict(new_component_spec)
|
||||
# actual library and class name of the module
|
||||
library, class_name = _fetch_class_library_tuple(module) # e.g. ("diffusers", "UNet2DConditionModel")
|
||||
|
||||
# extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config
|
||||
# e.g. {"repo": "stabilityai/stable-diffusion-2-1",
|
||||
# "type_hint": ("diffusers", "UNet2DConditionModel"),
|
||||
# "subfolder": "unet",
|
||||
# "variant": None,
|
||||
# "revision": None}
|
||||
component_spec_dict = self._component_spec_to_dict(component_spec)
|
||||
|
||||
else:
|
||||
# if module is None, e.g. self.register_components(unet=None) during __init__
|
||||
# we do not update the spec,
|
||||
# but we still need to update the modular_model_index.json config based oncomponent spec
|
||||
library, class_name = None, None
|
||||
# if module is None, we do not update the spec,
|
||||
# but we still need to update the config to make sure it's synced with the component spec
|
||||
# (in the case of the first time registration, we initilize the object with component spec, and then we call register_components() to register it to config)
|
||||
new_component_spec = component_spec
|
||||
component_spec_dict = self._component_spec_to_dict(component_spec)
|
||||
|
||||
# do not register if component is not to be loaded from pretrained
|
||||
if new_component_spec.default_creation_method == "from_pretrained":
|
||||
register_dict = {name: (library, class_name, component_spec_dict)}
|
||||
else:
|
||||
register_dict = {}
|
||||
register_dict = {name: (library, class_name, component_spec_dict)}
|
||||
|
||||
# set the component as attribute
|
||||
# if it is not set yet, just set it and skip the process to check and warn below
|
||||
if not is_registered:
|
||||
self.register_to_config(**register_dict)
|
||||
self._component_specs[name] = new_component_spec
|
||||
setattr(self, name, module)
|
||||
if module is not None and self._component_manager is not None:
|
||||
if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None:
|
||||
self._component_manager.add(name, module, self._collection)
|
||||
continue
|
||||
|
||||
@@ -1707,10 +1721,6 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
if current_module is module:
|
||||
logger.info(f"ModularLoader.register_components: {name} is already registered with same object, skipping")
|
||||
continue
|
||||
|
||||
# it module is not an instance of the expected type, still register it but with a warning
|
||||
if module is not None and component_spec.type_hint is not None and not isinstance(module, component_spec.type_hint):
|
||||
logger.warning(f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}")
|
||||
|
||||
# warn if unregister
|
||||
if current_module is not None and module is None:
|
||||
@@ -1718,7 +1728,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
f"ModularLoader.register_components: setting '{name}' to None "
|
||||
f"(was {current_module.__class__.__name__})"
|
||||
)
|
||||
# same type, new instance → debug
|
||||
# same type, new instance → replace but send debug log
|
||||
elif current_module is not None \
|
||||
and module is not None \
|
||||
and isinstance(module, current_module.__class__) \
|
||||
@@ -1728,13 +1738,12 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
f"(same type {type(current_module).__name__}, new instance)"
|
||||
)
|
||||
|
||||
# save modular_model_index.json config
|
||||
# update modular_model_index.json config
|
||||
self.register_to_config(**register_dict)
|
||||
# update component spec
|
||||
self._component_specs[name] = new_component_spec
|
||||
# finally set models
|
||||
setattr(self, name, module)
|
||||
if module is not None and self._component_manager is not None:
|
||||
# add to component manager if one is attached
|
||||
if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None:
|
||||
self._component_manager.add(name, module, self._collection)
|
||||
|
||||
|
||||
@@ -1758,6 +1767,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
config_dict = self.load_config(modular_repo, **kwargs)
|
||||
|
||||
for name, value in config_dict.items():
|
||||
# only update component_spec for from_pretrained components
|
||||
if name in self._component_specs and self._component_specs[name].default_creation_method == "from_pretrained" and isinstance(value, (tuple, list)) and len(value) == 3:
|
||||
library, class_name, component_spec_dict = value
|
||||
component_spec = self._dict_to_component_spec(name, component_spec_dict)
|
||||
@@ -1768,7 +1778,11 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
|
||||
register_components_dict = {}
|
||||
for name, component_spec in self._component_specs.items():
|
||||
register_components_dict[name] = None
|
||||
if component_spec.default_creation_method == "from_config":
|
||||
component = component_spec.create()
|
||||
else:
|
||||
component = None
|
||||
register_components_dict[name] = component
|
||||
self.register_components(**register_components_dict)
|
||||
|
||||
default_configs = {}
|
||||
@@ -1870,6 +1884,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
**kwargs: Component objects or configuration values to update:
|
||||
- Component objects: Must be created using ComponentSpec (e.g., `unet=new_unet, text_encoder=new_encoder`)
|
||||
- Configuration values: Simple values to update configuration settings (e.g., `requires_safety_checker=False`)
|
||||
- ComponentSpec objects: if passed a ComponentSpec object, only support from_config spec, will call create() method to create it
|
||||
|
||||
Raises:
|
||||
ValueError: If a component wasn't created using ComponentSpec (doesn't have `_diffusers_load_id` attribute)
|
||||
@@ -1893,22 +1908,52 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
unet=new_unet_model,
|
||||
requires_safety_checker=False
|
||||
)
|
||||
# update with ComponentSpec objects
|
||||
loader.update(
|
||||
guider=ComponentSpec(name="guider", type_hint=ClassifierFreeGuidance, config={"guidance_scale": 5.0}, default_creation_method="from_config")
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
# extract component_specs_updates & config_specs_updates from `specs`
|
||||
passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs}
|
||||
passed_component_specs = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs and isinstance(kwargs[k], ComponentSpec)}
|
||||
passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs and not isinstance(kwargs[k], ComponentSpec)}
|
||||
passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs}
|
||||
|
||||
for name, component in passed_components.items():
|
||||
if not hasattr(component, "_diffusers_load_id"):
|
||||
raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.")
|
||||
raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.")
|
||||
|
||||
# YiYi TODO: remove this if we remove support for non config mixin components in `create()` method
|
||||
if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin):
|
||||
raise ValueError(
|
||||
f"The passed component '{name}' is not supported in update() method "
|
||||
f"because it is not supported in `ComponentSpec.from_component()`. "
|
||||
f"Please pass a ComponentSpec object instead."
|
||||
)
|
||||
current_component_spec = self._component_specs[name]
|
||||
# warn if type changed
|
||||
if current_component_spec.type_hint is not None and not isinstance(component, current_component_spec.type_hint):
|
||||
logger.warning(f"ModularLoader.update: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}")
|
||||
# update _component_specs based on the new component
|
||||
new_component_spec = ComponentSpec.from_component(name, component)
|
||||
self._component_specs[name] = new_component_spec
|
||||
|
||||
if len(kwargs) > 0:
|
||||
logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}")
|
||||
|
||||
|
||||
self.register_components(**passed_components)
|
||||
created_components = {}
|
||||
for name, component_spec in passed_component_specs.items():
|
||||
if component_spec.default_creation_method == "from_pretrained":
|
||||
raise ValueError(f"ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update() method")
|
||||
created_components[name] = component_spec.create()
|
||||
current_component_spec = self._component_specs[name]
|
||||
# warn if type changed
|
||||
if current_component_spec.type_hint is not None and not isinstance(created_components[name], current_component_spec.type_hint):
|
||||
logger.warning(f"ModularLoader.update: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}")
|
||||
# update _component_specs based on the user passed component_spec
|
||||
self._component_specs[name] = component_spec
|
||||
self.register_components(**passed_components, **created_components)
|
||||
|
||||
|
||||
config_to_register = {}
|
||||
@@ -1932,8 +1977,9 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
- a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32}
|
||||
- if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`, `variant`, `revision`, etc.
|
||||
"""
|
||||
# if not specific name, load all the components with default_creation_method == "from_pretrained"
|
||||
if component_names is None:
|
||||
component_names = list(self._component_specs.keys())
|
||||
component_names = [name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_pretrained"]
|
||||
elif not isinstance(component_names, list):
|
||||
component_names = [component_names]
|
||||
|
||||
@@ -1958,7 +2004,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
# check if the default is specified
|
||||
component_load_kwargs[key] = value["default"]
|
||||
try:
|
||||
components_to_register[name] = spec.create(**component_load_kwargs)
|
||||
components_to_register[name] = spec.load(**component_load_kwargs)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create component '{name}': {e}")
|
||||
|
||||
@@ -1986,7 +2032,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs):
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs):
|
||||
|
||||
config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs)
|
||||
expected_component = set(config_dict.pop("_components_names"))
|
||||
@@ -2010,7 +2056,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
else:
|
||||
# append a empty component spec for these not in modular_model_index
|
||||
component_specs.append(ComponentSpec(name=name, default_creation_method="from_config"))
|
||||
return cls(component_specs + config_specs)
|
||||
return cls(component_specs + config_specs, component_manager=component_manager, collection=collection)
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user