mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Modular] update the collection behavior (#11963)
* only remove from the collection
This commit is contained in:
@@ -386,6 +386,7 @@ class ComponentsManager:
|
||||
id(component) is Python's built-in unique identifier for the object
|
||||
"""
|
||||
component_id = f"{name}_{id(component)}"
|
||||
is_new_component = True
|
||||
|
||||
# check for duplicated components
|
||||
for comp_id, comp in self.components.items():
|
||||
@@ -394,6 +395,7 @@ class ComponentsManager:
|
||||
if comp_name == name:
|
||||
logger.warning(f"ComponentsManager: component '{name}' already exists as '{comp_id}'")
|
||||
component_id = comp_id
|
||||
is_new_component = False
|
||||
break
|
||||
else:
|
||||
logger.warning(
|
||||
@@ -426,7 +428,9 @@ class ComponentsManager:
|
||||
logger.warning(
|
||||
f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}"
|
||||
)
|
||||
self.remove(comp_id)
|
||||
# remove existing component from this collection (if it is not in any other collection, will be removed from ComponentsManager)
|
||||
self.remove_from_collection(comp_id, collection)
|
||||
|
||||
self.collections[collection].add(component_id)
|
||||
logger.info(
|
||||
f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}"
|
||||
@@ -434,11 +438,29 @@ class ComponentsManager:
|
||||
else:
|
||||
logger.info(f"ComponentsManager: added component '{name}' as '{component_id}'")
|
||||
|
||||
if self._auto_offload_enabled:
|
||||
if self._auto_offload_enabled and is_new_component:
|
||||
self.enable_auto_cpu_offload(self._auto_offload_device)
|
||||
|
||||
return component_id
|
||||
|
||||
def remove_from_collection(self, component_id: str, collection: str):
|
||||
"""
|
||||
Remove a component from a collection.
|
||||
"""
|
||||
if collection not in self.collections:
|
||||
logger.warning(f"Collection '{collection}' not found in ComponentsManager")
|
||||
return
|
||||
if component_id not in self.collections[collection]:
|
||||
logger.warning(f"Component '{component_id}' not found in collection '{collection}'")
|
||||
return
|
||||
# remove from the collection
|
||||
self.collections[collection].remove(component_id)
|
||||
# check if this component is in any other collection
|
||||
comp_colls = [coll for coll, comps in self.collections.items() if component_id in comps]
|
||||
if not comp_colls: # only if no other collection contains this component, remove it
|
||||
logger.warning(f"ComponentsManager: removing component '{component_id}' from ComponentsManager")
|
||||
self.remove(component_id)
|
||||
|
||||
def remove(self, component_id: str = None):
|
||||
"""
|
||||
Remove a component from the ComponentsManager.
|
||||
|
||||
@@ -185,6 +185,8 @@ class ComponentSpec:
|
||||
Unique identifier for this spec's pretrained load, composed of repo|subfolder|variant|revision (no empty
|
||||
segments).
|
||||
"""
|
||||
if self.default_creation_method == "from_config":
|
||||
return "null"
|
||||
parts = [getattr(self, k) for k in self.loading_fields()]
|
||||
parts = ["null" if p is None else p for p in parts]
|
||||
return "|".join(p for p in parts if p)
|
||||
|
||||
Reference in New Issue
Block a user