diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index a31acb5a2d..ce6f47f67a 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from contextlib import contextmanager, nullcontext from typing import Dict, List, Optional, Set, Tuple, Union -import os -import torch import safetensors.torch +import torch + from ..utils import get_logger, is_accelerate_available from .hooks import HookRegistry, ModelHook @@ -165,9 +166,10 @@ class ModuleGroup: tensor_obj.data.record_stream(current_stream) else: # Load directly to the target device (synchronous) - loaded_tensors = safetensors.torch.load_file( - self.safetensors_file_path, device=self.onload_device + onload_device = ( + self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device ) + loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device) for key, tensor_obj in self.key_to_tensor.items(): tensor_obj.data = loaded_tensors[key] return @@ -265,16 +267,12 @@ class GroupOffloadingHook(ModelHook): _is_stateful = False - def __init__( - self, - group: ModuleGroup, - next_group: Optional[ModuleGroup] = None - ) -> None: + def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None: self.group = group self.next_group = next_group # map param/buffer name -> file path - self.param_to_path: Dict[str,str] = {} - self.buffer_to_path: Dict[str,str] = {} + self.param_to_path: Dict[str, str] = {} + self.buffer_to_path: Dict[str, str] = {} def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: if self.group.offload_leader == module: @@ -516,7 +514,6 @@ def apply_group_offloading( stream = torch.Stream() else: raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.") - if offload_to_disk and offload_path is None: raise ValueError("`offload_path` must be set when `offload_to_disk=True`.") @@ -899,4 +896,4 @@ def _get_group_onload_device(module: torch.nn.Module) -> torch.device: for submodule in module.modules(): if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None: return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device - raise ValueError("Group offloading is not enabled for the provided module.") \ No newline at end of file + raise ValueError("Group offloading is not enabled for the provided module.") diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 55ce0cf79f..ce57d17ab0 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -543,6 +543,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): onload_device: torch.device, offload_device: torch.device = torch.device("cpu"), offload_type: str = "block_level", + offload_to_disk: bool = False, + offload_path: Optional[str] = None, num_blocks_per_group: Optional[int] = None, non_blocking: bool = False, use_stream: bool = False, @@ -588,15 +590,17 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): f"open an issue at https://github.com/huggingface/diffusers/issues." ) apply_group_offloading( - self, - onload_device, - offload_device, - offload_type, - num_blocks_per_group, - non_blocking, - use_stream, - record_stream, + module=self, + onload_device=onload_device, + offload_device=offload_device, + offload_type=offload_type, + num_blocks_per_group=num_blocks_per_group, + non_blocking=non_blocking, + use_stream=use_stream, + record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, + offload_to_disk=offload_to_disk, + offload_path=offload_path, ) def save_pretrained(