1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix Group offloading behaviour when using streams (#11097)

* update

* update
This commit is contained in:
Aryan
2025-03-18 14:44:10 +05:30
committed by GitHub
parent cb1b8b21b8
commit 3be6706018

View File

@@ -181,6 +181,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
self._layer_execution_tracker_module_names = set()
def initialize_hook(self, module):
def make_execution_order_update_callback(current_name, current_submodule):
def callback():
logger.debug(f"Adding {current_name} to the execution order")
self.execution_order.append((current_name, current_submodule))
return callback
# To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
# of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
# layers are executed during the forward pass.
@@ -192,14 +199,8 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)
if group_offloading_hook is not None:
def make_execution_order_update_callback(current_name, current_submodule):
def callback():
logger.debug(f"Adding {current_name} to the execution order")
self.execution_order.append((current_name, current_submodule))
return callback
# For the first forward pass, we have to load in a blocking manner
group_offloading_hook.group.non_blocking = False
layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule))
registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER)
self._layer_execution_tracker_module_names.add(name)
@@ -229,6 +230,7 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
# Remove the layer execution tracker hooks from the submodules
base_module_registry = module._diffusers_hook
registries = [submodule._diffusers_hook for _, submodule in self.execution_order]
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
for i in range(num_executed):
registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False)
@@ -236,8 +238,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
# Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False)
# Apply lazy prefetching by setting required attributes
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
# LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True.
# We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to
# see the benefits of prefetching.
for hook in group_offloading_hooks:
hook.group.non_blocking = True
# Set required attributes for prefetching
if num_executed > 0:
base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group