From f238cb0736e9daa7d35ad9c7daa2b048d0076aa1 Mon Sep 17 00:00:00 2001
From: YiYi Xu
Date: Thu, 28 Mar 2024 08:23:02 -1000
Subject: [PATCH] cpu_offload: remove all hooks before offload (#7448)

* add remove_all_hooks

* a few more fix and tests

* up

* Update src/diffusers/pipelines/pipeline_utils.py

Co-authored-by: Pedro Cuenca

* split tests

* add

---------

Co-authored-by: Pedro Cuenca
---
 src/diffusers/pipelines/pipeline_utils.py | 32 ++++----
 tests/pipelines/test_pipelines_common.py  | 92 +++++++++++++++++++++++
 2 files changed, 110 insertions(+), 14 deletions(-)

diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index bee65aff57..f59c25c191 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -371,9 +371,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
                 return False
 
-            return hasattr(module, "_hf_hook") and not isinstance(
-                module._hf_hook, (accelerate.hooks.CpuOffload, accelerate.hooks.AlignDevicesHook)
-            )
+            return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
 
         def module_is_offloaded(module):
             if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
@@ -939,6 +937,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                     return torch.device(module._hf_hook.execution_device)
         return self.device
 
+    def remove_all_hooks(self):
+        r"""
+        Removes all hooks that were added when using `enable_sequential_cpu_offload` or `enable_model_cpu_offload`.
+        """
+        for _, model in self.components.items():
+            if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
+                is_sequential_cpu_offload = isinstance(getattr(model, "_hf_hook"), accelerate.hooks.AlignDevicesHook)
+                accelerate.hooks.remove_hook_from_module(model, recurse=is_sequential_cpu_offload)
+        self._all_hooks = []
+
     def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
@@ -963,6 +971,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         else:
             raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
 
+        self.remove_all_hooks()
+
         torch_device = torch.device(device)
         device_index = torch_device.index
 
@@ -979,15 +989,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         device = torch.device(f"{device_type}:{self._offload_gpu_id}")
         self._offload_device = device
 
-        if self.device.type != "cpu":
-            self.to("cpu", silence_dtype_warnings=True)
-            device_mod = getattr(torch, self.device.type, None)
-            if hasattr(device_mod, "empty_cache") and device_mod.is_available():
-                device_mod.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+        self.to("cpu", silence_dtype_warnings=True)
+        device_mod = getattr(torch, device.type, None)
+        if hasattr(device_mod, "empty_cache") and device_mod.is_available():
+            device_mod.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
 
         all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
 
-        self._all_hooks = []
         hook = None
         for model_str in self.model_cpu_offload_seq.split("->"):
             model = all_model_components.pop(model_str, None)
@@ -1021,11 +1029,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             # `enable_model_cpu_offload` has not be called, so silently do nothing
             return
 
-        for hook in self._all_hooks:
-            # offload model and remove hook from model
-            hook.offload()
-            hook.remove()
-
         # make sure the model is in the same state as before calling it
         self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
 
@@ -1048,6 +1051,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             from accelerate import cpu_offload
         else:
             raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
+        self.remove_all_hooks()
 
         torch_device = torch.device(device)
         device_index = torch_device.index
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 13007a2aa1..41292ec963 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -1107,6 +1107,98 @@ class PipelineTesterMixin:
             f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}",
         )
 
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"),
+        reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher",
+    )
+    def test_cpu_offload_forward_pass_twice(self, expected_max_diff=2e-4):
+        import accelerate
+
+        generator_device = "cpu"
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+
+        pipe.set_progress_bar_config(disable=None)
+
+        pipe.enable_model_cpu_offload()
+        inputs = self.get_dummy_inputs(generator_device)
+        output_with_offload = pipe(**inputs)[0]
+
+        pipe.enable_model_cpu_offload()
+        inputs = self.get_dummy_inputs(generator_device)
+        output_with_offload_twice = pipe(**inputs)[0]
+
+        max_diff = np.abs(to_np(output_with_offload) - to_np(output_with_offload_twice)).max()
+        self.assertLess(
+            max_diff, expected_max_diff, "running CPU offloading 2nd time should not affect the inference results"
+        )
+        offloaded_modules = [
+            v
+            for k, v in pipe.components.items()
+            if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
+        ]
+        self.assertTrue(
+            all(v.device.type == "cpu" for v in offloaded_modules),
+            f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}",
+        )
+
+        offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr(v, "_hf_hook")]
+        self.assertTrue(
+            all(isinstance(v._hf_hook, accelerate.hooks.CpuOffload) for v in offloaded_modules_with_hooks),
+            f"Not installed correct hook: {[v for v in offloaded_modules_with_hooks if not isinstance(v._hf_hook, accelerate.hooks.CpuOffload)]}",
+        )
+
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
+        reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher",
+    )
+    def test_sequential_offload_forward_pass_twice(self, expected_max_diff=2e-4):
+        import accelerate
+
+        generator_device = "cpu"
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+
+        pipe.set_progress_bar_config(disable=None)
+
+        pipe.enable_sequential_cpu_offload()
+        inputs = self.get_dummy_inputs(generator_device)
+        output_with_offload = pipe(**inputs)[0]
+
+        pipe.enable_sequential_cpu_offload()
+        inputs = self.get_dummy_inputs(generator_device)
+        output_with_offload_twice = pipe(**inputs)[0]
+
+        max_diff = np.abs(to_np(output_with_offload) - to_np(output_with_offload_twice)).max()
+        self.assertLess(
+            max_diff, expected_max_diff, "running sequential offloading 2nd time should not affect the inference results"
+        )
+        offloaded_modules = [
+            v
+            for k, v in pipe.components.items()
+            if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
+        ]
+        self.assertTrue(
+            all(v.device.type == "meta" for v in offloaded_modules),
+            f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'meta']}",
+        )
+
+        offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr(v, "_hf_hook")]
+        self.assertTrue(
+            all(
+                isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook) for v in offloaded_modules_with_hooks
+            ),
+            f"Not installed correct hook: {[v for v in offloaded_modules_with_hooks if not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook)]}",
+        )
+
     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),
         reason="XFormers attention is only available with CUDA and `xformers` installed",
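
Usage sketch (not part of the patch): with this change, each call to `enable_model_cpu_offload` or
`enable_sequential_cpu_offload` first removes any hooks left by a previous call, and the new
`DiffusionPipeline.remove_all_hooks` method exposes that cleanup directly. A minimal illustration,
assuming a CUDA device and an available checkpoint (the model id below is only an example):

    import torch
    from diffusers import DiffusionPipeline

    # Example checkpoint; any DiffusionPipeline with a model_cpu_offload_seq works the same way.
    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )

    # Calling enable_* repeatedly (or switching between the two offload modes)
    # no longer stacks stale accelerate hooks, because each call starts by
    # removing the hooks from the previous one.
    pipe.enable_model_cpu_offload()
    _ = pipe("an astronaut riding a horse", num_inference_steps=2)

    pipe.enable_sequential_cpu_offload()
    _ = pipe("an astronaut riding a horse", num_inference_steps=2)

    # Hooks can also be cleared explicitly, e.g. before moving the pipeline by hand.
    pipe.remove_all_hooks()
    pipe.to("cuda")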