Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
update the logic of is_sequential_cpu_offload (#7788)

* up
* add comment to the tests + fix dit

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
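In short: accelerate's sequential CPU offload can attach the AlignDevicesHook either directly as a module's `_hf_hook` or wrapped inside a SequentialHook, and the old plain isinstance check missed the wrapped case. A minimal sketch of the detection logic this commit standardizes on (the helper name is ours for illustration; the attribute access mirrors the diffs below):

import accelerate


def is_sequentially_offloaded(module) -> bool:
    # Hedged sketch, not diffusers API.
    if not hasattr(module, "_hf_hook"):
        return False
    hook = module._hf_hook
    # Direct case: the hook on the module is an AlignDevicesHook itself.
    if isinstance(hook, accelerate.hooks.AlignDevicesHook):
        return True
    # Wrapped case: a SequentialHook whose first sub-hook aligns devices.
    return hasattr(hook, "hooks") and isinstance(
        hook.hooks[0], accelerate.hooks.AlignDevicesHook
    )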
@@ -1304,7 +1304,11 @@ class DemoFusionSDXLPipeline(
             if isinstance(component, torch.nn.Module):
                 if hasattr(component, "_hf_hook"):
                     is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                    is_sequential_cpu_offload = (
+                        isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                        or hasattr(component._hf_hook, "hooks")
+                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                    )
                     logger.info(
                         "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                     )
@@ -369,7 +369,11 @@ class LoraLoaderMixin:
                 if not is_model_cpu_offload:
                     is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                 if not is_sequential_cpu_offload:
-                    is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)
+                    is_sequential_cpu_offload = (
+                        isinstance(component._hf_hook, AlignDevicesHook)
+                        or hasattr(component._hf_hook, "hooks")
+                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                    )
 
                 logger.info(
                     "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
@@ -423,7 +423,11 @@ class TextualInversionLoaderMixin:
         if isinstance(component, nn.Module):
             if hasattr(component, "_hf_hook"):
                 is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                is_sequential_cpu_offload = (
+                    isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                    or hasattr(component._hf_hook, "hooks")
+                    and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                )
                 logger.info(
                     "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
                 )
@@ -359,7 +359,11 @@ class UNet2DConditionLoadersMixin:
         for _, component in _pipeline.components.items():
             if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                 is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                is_sequential_cpu_offload = (
+                    isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                    or hasattr(component._hf_hook, "hooks")
+                    and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                )
 
                 logger.info(
                     "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
@@ -227,6 +227,9 @@ class DiTPipeline(DiffusionPipeline):
         if output_type == "pil":
             samples = self.numpy_to_pil(samples)
 
+        # Offload all models
+        self.maybe_free_model_hooks()
+
         if not return_dict:
             return (samples,)
 
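For context, maybe_free_model_hooks() is the standard end-of-__call__ cleanup that lets offloading behave consistently across repeated pipeline invocations; the DiT fix above simply adds it. A hedged usage sketch (the checkpoint id is illustrative, not taken from this commit):

import torch
from diffusers import DiTPipeline

# Hedged sketch; requires a CUDA device and accelerate installed.
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

# With maybe_free_model_hooks() running at the end of __call__, the hooks are
# reset after each run, so a second call offloads exactly like the first.
first = pipe(class_labels=[0], num_inference_steps=25).images
second = pipe(class_labels=[1], num_inference_steps=25).images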
@@ -376,7 +376,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
                 return False
 
-            return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
+            return hasattr(module, "_hf_hook") and (
+                isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
+                or hasattr(module._hf_hook, "hooks")
+                and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook)
+            )
 
         def module_is_offloaded(module):
             if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
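To see why both branches are needed, one can inspect the hook that enable_sequential_cpu_offload() leaves on a component: depending on the accelerate version and module, `_hf_hook` is either an AlignDevicesHook or a SequentialHook whose `hooks` list contains one. A hedged inspection sketch (the checkpoint id is illustrative):

import accelerate
import torch
from diffusers import DiffusionPipeline

# Hedged sketch; requires a CUDA device and accelerate installed.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.enable_sequential_cpu_offload()

hook = pipe.unet._hf_hook
if isinstance(hook, accelerate.hooks.SequentialHook):
    # Wrapped case: the AlignDevicesHook sits inside the SequentialHook.
    print([type(h).__name__ for h in hook.hooks])
else:
    # Direct case.
    print(type(hook).__name__)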
@@ -1005,8 +1009,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         """
         for _, model in self.components.items():
             if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
-                is_sequential_cpu_offload = isinstance(getattr(model, "_hf_hook"), accelerate.hooks.AlignDevicesHook)
-                accelerate.hooks.remove_hook_from_module(model, recurse=is_sequential_cpu_offload)
+                accelerate.hooks.remove_hook_from_module(model, recurse=True)
         self._all_hooks = []
 
     def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
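The simplification above drops the hook-type check entirely: remove_hook_from_module(model, recurse=True) also clears hooks from submodules, which covers the per-submodule AlignDevicesHooks that sequential offload installs, so recursing unconditionally is safe for both offload modes. A hedged standalone sketch of the same cleanup:

import accelerate
import torch


def free_all_hooks(components: dict) -> None:
    # Hedged sketch mirroring maybe_free_model_hooks() after this commit.
    for _, model in components.items():
        if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
            # recurse=True removes hooks from submodules too, so SequentialHook
            # wrappers and per-submodule AlignDevicesHooks are both cleared.
            accelerate.hooks.remove_hook_from_module(model, recurse=True)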
@@ -324,10 +324,6 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=1e-3)
 
-    # PixArt transformer model does not work with sequential offload so skip it for now
-    def test_sequential_offload_forward_pass_twice(self):
-        pass
-
 
 @slow
 @require_torch_gpu
@@ -308,10 +308,6 @@ class PixArtSigmaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=1e-3)
 
-    # PixArt transformer model does not work with sequential offload so skip it for now
-    def test_sequential_offload_forward_pass_twice(self):
-        pass
-
 
 @slow
 @require_torch_gpu
@@ -1360,6 +1360,8 @@ class PipelineTesterMixin:
         reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher",
     )
     def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
+        import accelerate
+
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         for component in pipe.components.values():
@@ -1373,6 +1375,7 @@ class PipelineTesterMixin:
         output_without_offload = pipe(**inputs)[0]
 
         pipe.enable_sequential_cpu_offload()
+        assert pipe._execution_device.type == pipe._offload_device.type
 
         inputs = self.get_dummy_inputs(generator_device)
         output_with_offload = pipe(**inputs)[0]
@@ -1380,11 +1383,48 @@ class PipelineTesterMixin:
         max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
         self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")
 
+        # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
+        offloaded_modules = {
+            k: v
+            for k, v in pipe.components.items()
+            if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
+        }
+        # 1. all offloaded modules should be saved to cpu and moved to meta device
+        self.assertTrue(
+            all(v.device.type == "meta" for v in offloaded_modules.values()),
+            f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'meta']}",
+        )
+        # 2. all offloaded modules should have hook installed
+        self.assertTrue(
+            all(hasattr(v, "_hf_hook") for k, v in offloaded_modules.items()),
+            f"No hook attached: {[k for k, v in offloaded_modules.items() if not hasattr(v, '_hf_hook')]}",
+        )
+        # 3. all offloaded modules should have correct hooks installed, should be either one of these two
+        # - `AlignDevicesHook`
+        # - a `SequentialHook` that contains `AlignDevicesHook`
+        offloaded_modules_with_incorrect_hooks = {}
+        for k, v in offloaded_modules.items():
+            if hasattr(v, "_hf_hook"):
+                if isinstance(v._hf_hook, accelerate.hooks.SequentialHook):
+                    # if it is a `SequentialHook`, we loop through its `hooks` attribute to check if it only contains `AlignDevicesHook`
+                    for hook in v._hf_hook.hooks:
+                        if not isinstance(hook, accelerate.hooks.AlignDevicesHook):
+                            offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook.hooks[0])
+                elif not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook):
+                    offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)
+
+        self.assertTrue(
+            len(offloaded_modules_with_incorrect_hooks) == 0,
+            f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}",
+        )
 
     @unittest.skipIf(
         torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"),
         reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher",
     )
     def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
+        import accelerate
+
         generator_device = "cpu"
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
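The three invariants the tests now assert (module on the expected device, a `_hf_hook` attached, the hook of the expected type after unwrapping any SequentialHook) can be read as one small validator. A hedged sketch, not part of the test suite; it assumes the components expose a `.device` attribute, as diffusers models do:

import accelerate


def find_offload_problems(modules: dict, expected_device: str, expected_hook_cls) -> dict:
    """Hedged sketch: return {name: reason} for modules violating the offload invariants."""
    problems = {}
    for name, module in modules.items():
        if module.device.type != expected_device:
            problems[name] = f"on {module.device.type}, expected {expected_device}"
        elif not hasattr(module, "_hf_hook"):
            problems[name] = "no _hf_hook attached"
        else:
            hook = module._hf_hook
            # Unwrap a SequentialHook into its sub-hooks before type-checking.
            hooks = hook.hooks if isinstance(hook, accelerate.hooks.SequentialHook) else [hook]
            if not all(isinstance(h, expected_hook_cls) for h in hooks):
                problems[name] = f"unexpected hook type: {type(hook).__name__}"
    return problems

For sequential offload the call would be find_offload_problems(offloaded_modules, "meta", accelerate.hooks.AlignDevicesHook); for model offload, ("cpu", accelerate.hooks.CpuOffload).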
@@ -1400,19 +1440,39 @@ class PipelineTesterMixin:
         output_without_offload = pipe(**inputs)[0]
 
         pipe.enable_model_cpu_offload()
+        assert pipe._execution_device.type == pipe._offload_device.type
 
         inputs = self.get_dummy_inputs(generator_device)
         output_with_offload = pipe(**inputs)[0]
 
         max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
         self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")
-        offloaded_modules = [
-            v
+
+        # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
+        offloaded_modules = {
+            k: v
             for k, v in pipe.components.items()
             if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
-        ]
-        (
-            self.assertTrue(all(v.device.type == "cpu" for v in offloaded_modules)),
-            f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}",
+        }
+        # 1. check if all offloaded modules are saved to cpu
+        self.assertTrue(
+            all(v.device.type == "cpu" for v in offloaded_modules.values()),
+            f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'cpu']}",
         )
+        # 2. check if all offloaded modules have hooks installed
+        self.assertTrue(
+            all(hasattr(v, "_hf_hook") for k, v in offloaded_modules.items()),
+            f"No hook attached: {[k for k, v in offloaded_modules.items() if not hasattr(v, '_hf_hook')]}",
+        )
+        # 3. check if all offloaded modules have correct type of hooks installed, should be `CpuOffload`
+        offloaded_modules_with_incorrect_hooks = {}
+        for k, v in offloaded_modules.items():
+            if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.CpuOffload):
+                offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)
+
+        self.assertTrue(
+            len(offloaded_modules_with_incorrect_hooks) == 0,
+            f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}",
+        )
 
     @unittest.skipIf(
@@ -1444,16 +1504,24 @@ class PipelineTesterMixin:
         self.assertLess(
             max_diff, expected_max_diff, "running CPU offloading 2nd time should not affect the inference results"
         )
 
+        # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
+        offloaded_modules = {
+            k: v
+            for k, v in pipe.components.items()
+            if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
+        }
+        # 1. check if all offloaded modules are saved to cpu
+        self.assertTrue(
+            all(v.device.type == "cpu" for v in offloaded_modules.values()),
+            f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'cpu']}",
+        )
+
+        # 2. check if all offloaded modules have hooks installed
+        self.assertTrue(
+            all(hasattr(v, "_hf_hook") for k, v in offloaded_modules.items()),
+            f"No hook attached: {[k for k, v in offloaded_modules.items() if not hasattr(v, '_hf_hook')]}",
+        )
+        # 3. check if all offloaded modules have correct type of hooks installed, should be `CpuOffload`
+        offloaded_modules_with_incorrect_hooks = {}
+        for k, v in offloaded_modules.items():
+            if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.CpuOffload):
@@ -1493,19 +1561,36 @@ class PipelineTesterMixin:
         self.assertLess(
             max_diff, expected_max_diff, "running sequential offloading second time should have the inference results"
         )
 
+        # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
+        offloaded_modules = {
+            k: v
+            for k, v in pipe.components.items()
+            if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
+        }
+        # 1. check if all offloaded modules are moved to meta device
+        self.assertTrue(
+            all(v.device.type == "meta" for v in offloaded_modules.values()),
+            f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'meta']}",
+        )
+        # 2. check if all offloaded modules have hook installed
+        self.assertTrue(
+            all(hasattr(v, "_hf_hook") for k, v in offloaded_modules.items()),
+            f"No hook attached: {[k for k, v in offloaded_modules.items() if not hasattr(v, '_hf_hook')]}",
+        )
+        # 3. check if all offloaded modules have correct hooks installed, should be either one of these two
+        # - `AlignDevicesHook`
+        # - a `SequentialHook` that contains `AlignDevicesHook`
         offloaded_modules_with_incorrect_hooks = {}
         for k, v in offloaded_modules.items():
-            if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook):
-                offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)
+            if hasattr(v, "_hf_hook"):
+                if isinstance(v._hf_hook, accelerate.hooks.SequentialHook):
+                    # if it is a `SequentialHook`, we loop through its `hooks` attribute to check if it only contains `AlignDevicesHook`
+                    for hook in v._hf_hook.hooks:
+                        if not isinstance(hook, accelerate.hooks.AlignDevicesHook):
+                            offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook.hooks[0])
+                elif not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook):
+                    offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)
 
         self.assertTrue(
             len(offloaded_modules_with_incorrect_hooks) == 0,