mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
updates.patch
This commit is contained in:
@@ -12,12 +12,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
import os
|
||||
|
||||
import torch
|
||||
import safetensors.torch
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger, is_accelerate_available
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
@@ -165,9 +166,10 @@ class ModuleGroup:
|
||||
tensor_obj.data.record_stream(current_stream)
|
||||
else:
|
||||
# Load directly to the target device (synchronous)
|
||||
loaded_tensors = safetensors.torch.load_file(
|
||||
self.safetensors_file_path, device=self.onload_device
|
||||
onload_device = (
|
||||
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
|
||||
)
|
||||
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
|
||||
for key, tensor_obj in self.key_to_tensor.items():
|
||||
tensor_obj.data = loaded_tensors[key]
|
||||
return
|
||||
@@ -265,16 +267,12 @@ class GroupOffloadingHook(ModelHook):
|
||||
|
||||
_is_stateful = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group: ModuleGroup,
|
||||
next_group: Optional[ModuleGroup] = None
|
||||
) -> None:
|
||||
def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
|
||||
self.group = group
|
||||
self.next_group = next_group
|
||||
# map param/buffer name -> file path
|
||||
self.param_to_path: Dict[str,str] = {}
|
||||
self.buffer_to_path: Dict[str,str] = {}
|
||||
self.param_to_path: Dict[str, str] = {}
|
||||
self.buffer_to_path: Dict[str, str] = {}
|
||||
|
||||
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
if self.group.offload_leader == module:
|
||||
@@ -516,7 +514,6 @@ def apply_group_offloading(
|
||||
stream = torch.Stream()
|
||||
else:
|
||||
raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
|
||||
|
||||
if offload_to_disk and offload_path is None:
|
||||
raise ValueError("`offload_path` must be set when `offload_to_disk=True`.")
|
||||
|
||||
@@ -899,4 +896,4 @@ def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
|
||||
for submodule in module.modules():
|
||||
if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
|
||||
return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
|
||||
raise ValueError("Group offloading is not enabled for the provided module.")
|
||||
raise ValueError("Group offloading is not enabled for the provided module.")
|
||||
|
||||
@@ -543,6 +543,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
onload_device: torch.device,
|
||||
offload_device: torch.device = torch.device("cpu"),
|
||||
offload_type: str = "block_level",
|
||||
offload_to_disk: bool = False,
|
||||
offload_path: Optional[str] = None,
|
||||
num_blocks_per_group: Optional[int] = None,
|
||||
non_blocking: bool = False,
|
||||
use_stream: bool = False,
|
||||
@@ -588,15 +590,17 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
f"open an issue at https://github.com/huggingface/diffusers/issues."
|
||||
)
|
||||
apply_group_offloading(
|
||||
self,
|
||||
onload_device,
|
||||
offload_device,
|
||||
offload_type,
|
||||
num_blocks_per_group,
|
||||
non_blocking,
|
||||
use_stream,
|
||||
record_stream,
|
||||
module=self,
|
||||
onload_device=onload_device,
|
||||
offload_device=offload_device,
|
||||
offload_type=offload_type,
|
||||
num_blocks_per_group=num_blocks_per_group,
|
||||
non_blocking=non_blocking,
|
||||
use_stream=use_stream,
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
offload_to_disk=offload_to_disk,
|
||||
offload_path=offload_path,
|
||||
)
|
||||
|
||||
def save_pretrained(
|
||||
|
||||
Reference in New Issue
Block a user