mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
style
This commit is contained in:
@@ -21,13 +21,12 @@ from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..hooks import ModelHook
|
||||
from ..utils import (
|
||||
is_accelerate_available,
|
||||
logging,
|
||||
)
|
||||
|
||||
from ..hooks import ModelHook
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
|
||||
@@ -69,6 +68,7 @@ class CustomOffloadHook(ModelHook):
|
||||
The device on which the model should be executed. Will default to the MPS device if it's available, then
|
||||
GPU 0 if there is a GPU, and finally to the CPU.
|
||||
"""
|
||||
|
||||
no_grad = False
|
||||
|
||||
def __init__(
|
||||
@@ -541,10 +541,9 @@ class ComponentsManager:
|
||||
raise ValueError(f"Invalid type for names: {type(names)}")
|
||||
|
||||
def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda", memory_reserve_margin="3GB"):
|
||||
|
||||
if not is_accelerate_available():
|
||||
raise ImportError("Make sure to install accelerate to use auto_cpu_offload")
|
||||
|
||||
|
||||
for name, component in self.components.items():
|
||||
if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
|
||||
remove_hook_from_module(component, recurse=True)
|
||||
|
||||
Reference in New Issue
Block a user