mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
cpu_offload: remove all hooks before offload (#7448)
* add remove_all_hooks * a few more fix and tests * up * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * split tests * add --------- Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
@@ -371,9 +371,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
|
||||
return False
|
||||
|
||||
return hasattr(module, "_hf_hook") and not isinstance(
|
||||
module._hf_hook, (accelerate.hooks.CpuOffload, accelerate.hooks.AlignDevicesHook)
|
||||
)
|
||||
return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
|
||||
|
||||
def module_is_offloaded(module):
|
||||
if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
|
||||
@@ -939,6 +937,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
def remove_all_hooks(self):
|
||||
r"""
|
||||
Removes all hooks that were added when using `enable_sequential_cpu_offload` or `enable_model_cpu_offload`.
|
||||
"""
|
||||
for _, model in self.components.items():
|
||||
if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
|
||||
is_sequential_cpu_offload = isinstance(getattr(model, "_hf_hook"), accelerate.hooks.AlignDevicesHook)
|
||||
accelerate.hooks.remove_hook_from_module(model, recurse=is_sequential_cpu_offload)
|
||||
self._all_hooks = []
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
@@ -963,6 +971,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
self.remove_all_hooks()
|
||||
|
||||
torch_device = torch.device(device)
|
||||
device_index = torch_device.index
|
||||
|
||||
@@ -979,15 +989,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
device = torch.device(f"{device_type}:{self._offload_gpu_id}")
|
||||
self._offload_device = device
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
device_mod = getattr(torch, self.device.type, None)
|
||||
if hasattr(device_mod, "empty_cache") and device_mod.is_available():
|
||||
device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
device_mod = getattr(torch, device.type, None)
|
||||
if hasattr(device_mod, "empty_cache") and device_mod.is_available():
|
||||
device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
|
||||
|
||||
self._all_hooks = []
|
||||
hook = None
|
||||
for model_str in self.model_cpu_offload_seq.split("->"):
|
||||
model = all_model_components.pop(model_str, None)
|
||||
@@ -1021,11 +1029,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
# `enable_model_cpu_offload` has not be called, so silently do nothing
|
||||
return
|
||||
|
||||
for hook in self._all_hooks:
|
||||
# offload model and remove hook from model
|
||||
hook.offload()
|
||||
hook.remove()
|
||||
|
||||
# make sure the model is in the same state as before calling it
|
||||
self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
|
||||
|
||||
@@ -1048,6 +1051,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
|
||||
self.remove_all_hooks()
|
||||
|
||||
torch_device = torch.device(device)
|
||||
device_index = torch_device.index
|
||||
|
||||
@@ -1107,6 +1107,98 @@ class PipelineTesterMixin:
|
||||
f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}",
|
||||
)
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"),
|
||||
reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher",
|
||||
)
|
||||
def test_cpu_offload_forward_pass_twice(self, expected_max_diff=2e-4):
|
||||
import accelerate
|
||||
|
||||
generator_device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe.enable_model_cpu_offload()
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_offload = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_model_cpu_offload()
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_offload_twice = pipe(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(to_np(output_with_offload) - to_np(output_with_offload_twice)).max()
|
||||
self.assertLess(
|
||||
max_diff, expected_max_diff, "running CPU offloading 2nd time should not affect the inference results"
|
||||
)
|
||||
offloaded_modules = [
|
||||
v
|
||||
for k, v in pipe.components.items()
|
||||
if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
|
||||
]
|
||||
(
|
||||
self.assertTrue(all(v.device.type == "cpu" for v in offloaded_modules)),
|
||||
f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}",
|
||||
)
|
||||
|
||||
offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr(v, "_hf_hook")]
|
||||
(
|
||||
self.assertTrue(all(isinstance(v, accelerate.hooks.CpuOffload) for v in offloaded_modules_with_hooks)),
|
||||
f"Not installed correct hook: {[v for v in offloaded_modules_with_hooks if not isinstance(v, accelerate.hooks.CpuOffload)]}",
|
||||
)
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
|
||||
reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher",
|
||||
)
|
||||
def test_sequential_offload_forward_pass_twice(self, expected_max_diff=2e-4):
|
||||
import accelerate
|
||||
|
||||
generator_device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_offload = pipe(**inputs)[0]
|
||||
|
||||
pipe.nable_sequential_cpu_offload()
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_offload_twice = pipe(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(to_np(output_with_offload) - to_np(output_with_offload_twice)).max()
|
||||
self.assertLess(
|
||||
max_diff, expected_max_diff, "running sequential offloading second time should have the inference results"
|
||||
)
|
||||
offloaded_modules = [
|
||||
v
|
||||
for k, v in pipe.components.items()
|
||||
if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
|
||||
]
|
||||
(
|
||||
self.assertTrue(all(v.device.type == "meta" for v in offloaded_modules)),
|
||||
f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'meta']}",
|
||||
)
|
||||
|
||||
offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr(v, "_hf_hook")]
|
||||
(
|
||||
self.assertTrue(
|
||||
all(isinstance(v, accelerate.hooks.AlignDevicesHook) for v in offloaded_modules_with_hooks)
|
||||
),
|
||||
f"Not installed correct hook: {[v for v in offloaded_modules_with_hooks if not isinstance(v, accelerate.hooks.AlignDevicesHook)]}",
|
||||
)
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
|
||||
Reference in New Issue
Block a user