mirror of https://github.com/huggingface/diffusers.git
use diffusers ModelHook, raise an import error for accelerate inside enable_auto_cpu_offload
@@ -26,9 +26,11 @@ from ..utils import (
     logging,
 )
 
+from ..hooks import ModelHook
+
 
 if is_accelerate_available():
-    from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module
+    from accelerate.hooks import add_hook_to_module, remove_hook_from_module
     from accelerate.state import PartialState
     from accelerate.utils import send_to_device
     from accelerate.utils.memory import clear_device_cache
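The first hunk moves the hook base class to diffusers' own `ModelHook` (from `..hooks`) and keeps only the helper functions behind the accelerate availability check. A minimal sketch of that guarded-import pattern, with stand-in definitions for `is_accelerate_available` and the hook class (both illustrative, not diffusers' actual code):

```python
import importlib.util


def is_accelerate_available() -> bool:
    # Stand-in for diffusers' utility of the same name: True when the
    # optional `accelerate` dependency can be imported.
    return importlib.util.find_spec("accelerate") is not None


class ModelHook:
    # Stand-in for the library's own hook base class (..hooks.ModelHook).
    # Because it lives in the library itself, it is importable even when
    # accelerate is not installed.
    pass


if is_accelerate_available():
    # Only the module-wrapping helpers still come from accelerate.
    from accelerate.hooks import add_hook_to_module, remove_hook_from_module
```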
@@ -67,6 +69,7 @@ class CustomOffloadHook(ModelHook):
             The device on which the model should be executed. Will default to the MPS device if it's available, then
             GPU 0 if there is a GPU, and finally to the CPU.
     """
+    no_grad = False
 
     def __init__(
         self,
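With the base class now coming from diffusers rather than accelerate, the hook restates the `no_grad` flag that accelerate's own `ModelHook` declares and that its hook machinery consults when wrapping `forward`. An illustrative sketch of how such a flag can gate autograd (not accelerate's actual wrapper code):

```python
import torch


class CustomOffloadHook:
    # Class-level flag mirroring the attribute the hunk adds; hook machinery
    # can read it to decide whether to disable gradient tracking.
    no_grad = False


def call_with_hook(module: torch.nn.Module, hook: CustomOffloadHook, *args, **kwargs):
    # Illustrative wrapper: run the module without autograd when the hook
    # requests it, otherwise call it normally.
    if hook.no_grad:
        with torch.no_grad():
            return module(*args, **kwargs)
    return module(*args, **kwargs)


layer = torch.nn.Linear(4, 4)
out = call_with_hook(layer, CustomOffloadHook(), torch.randn(1, 4))
print(out.requires_grad)  # True here, since no_grad is False on this hook
```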
@@ -538,6 +541,10 @@ class ComponentsManager:
             raise ValueError(f"Invalid type for names: {type(names)}")
 
     def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda", memory_reserve_margin="3GB"):
+        if not is_accelerate_available():
+            raise ImportError("Make sure to install accelerate to use auto_cpu_offload")
+
         for name, component in self.components.items():
             if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
                 remove_hook_from_module(component, recurse=True)
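The last hunk is the import-error half of the commit message: `enable_auto_cpu_offload` now fails fast when accelerate is missing, before it starts detaching stale hooks. A rough standalone sketch of that guard-then-cleanup shape, with `components` as a hypothetical name-to-module dict standing in for the state the real method reads off `ComponentsManager`:

```python
import importlib.util

import torch


def is_accelerate_available() -> bool:
    return importlib.util.find_spec("accelerate") is not None


def enable_auto_cpu_offload(components: dict) -> None:
    # Fail fast, as in the hunk, instead of hitting a NameError later when
    # the conditionally imported accelerate helpers are first used.
    if not is_accelerate_available():
        raise ImportError("Make sure to install accelerate to use auto_cpu_offload")

    from accelerate.hooks import remove_hook_from_module

    for name, component in components.items():
        # accelerate stores its hook on the module as `_hf_hook`; strip any
        # existing hooks so fresh offload hooks can be attached cleanly.
        if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
            remove_hook_from_module(component, recurse=True)
```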