Merge branch 'main' into chroma-docs
@@ -302,6 +302,13 @@ compute-bound, [group-offloading](#group-offloading) tends to be better. Group o
</Tip>

### Offloading to disk

Group offloading can consume significant system RAM depending on the model size. In limited RAM environments,
it can be useful to offload the weights to disk (secondary storage) instead. You can do this by setting the `offload_to_disk_path`
argument in either [`~ModelMixin.enable_group_offload`] or [`~hooks.apply_group_offloading`]. Refer to [this description](https://github.com/huggingface/diffusers/pull/11682#issue-3129365363) and
[this comment](https://github.com/huggingface/diffusers/pull/11682#issuecomment-2955715126) for the expected speed-memory trade-offs with this option enabled.
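For example, a minimal sketch of disk offloading (the checkpoint and offload directory below are illustrative, not part of the original doc):

```py
import torch
from diffusers import FluxTransformer2DModel

# Illustrative checkpoint; any diffusers model that supports group offloading works the same way.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Offloaded groups are written to `offload_to_disk_path` as safetensors files and read back
# during the forward pass, trading disk I/O for lower system RAM usage.
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    offload_to_disk_path="/path/to/offload_dir",
)
```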

## Layerwise casting

Layerwise casting stores weights in a smaller data format (for example, `torch.float8_e4m3fn` or `torch.float8_e5m2`) to use less memory and upcasts them to a higher precision like `torch.float16` or `torch.bfloat16` for computation. Certain layers (normalization- and modulation-related weights) are skipped because storing them in fp8 can degrade generation quality.
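A minimal sketch of enabling it on a model (the checkpoint choice is illustrative):

```py
import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Weights are stored in fp8 and upcast to bfloat16 on the fly for each computation.
transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
```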
@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from contextlib import contextmanager, nullcontext
from typing import Dict, List, Optional, Set, Tuple, Union

import safetensors.torch
import torch

from ..utils import get_logger, is_accelerate_available
@@ -59,6 +61,7 @@ class ModuleGroup:
        record_stream: Optional[bool] = False,
        low_cpu_mem_usage: bool = False,
        onload_self: bool = True,
        offload_to_disk_path: Optional[str] = None,
    ) -> None:
        self.modules = modules
        self.offload_device = offload_device
@@ -72,7 +75,26 @@ class ModuleGroup:
        self.record_stream = record_stream
        self.onload_self = onload_self
        self.low_cpu_mem_usage = low_cpu_mem_usage
        self.cpu_param_dict = self._init_cpu_param_dict()

        self.offload_to_disk_path = offload_to_disk_path
        self._is_offloaded_to_disk = False

        if self.offload_to_disk_path:
            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")

            all_tensors = []
            for module in self.modules:
                all_tensors.extend(list(module.parameters()))
                all_tensors.extend(list(module.buffers()))
            all_tensors.extend(self.parameters)
            all_tensors.extend(self.buffers)
            all_tensors = list(dict.fromkeys(all_tensors))  # Remove duplicates

            self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
            self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
            self.cpu_param_dict = {}
        else:
            self.cpu_param_dict = self._init_cpu_param_dict()

        if self.stream is None and self.record_stream:
            raise ValueError("`record_stream` cannot be True when `stream` is None.")
@@ -124,6 +146,30 @@ class ModuleGroup:
        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None

        if self.offload_to_disk_path:
            if self.stream is not None:
                # Wait for previous Host->Device transfer to complete
                self.stream.synchronize()

            with context:
                if self.stream is not None:
                    # Load to CPU, pin, and async copy to device for overlapping transfer and compute
                    loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
                    for key, tensor_obj in self.key_to_tensor.items():
                        pinned_tensor = loaded_cpu_tensors[key].pin_memory()
                        tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
                        if self.record_stream:
                            tensor_obj.data.record_stream(current_stream)
                else:
                    # Load directly to the target device (synchronous)
                    onload_device = (
                        self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
                    )
                    loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
                    for key, tensor_obj in self.key_to_tensor.items():
                        tensor_obj.data = loaded_tensors[key]
            return

        if self.stream is not None:
            # Wait for previous Host->Device transfer to complete
            self.stream.synchronize()
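The streamed branch above is the standard pin-then-async-copy recipe. A self-contained sketch of the same pattern, separate from the classes in this diff (the function name and file path are made up):

```py
import safetensors.torch
import torch


def load_group_async(path: str, device: torch.device, copy_stream: torch.cuda.Stream) -> dict:
    """Stream tensors from a safetensors file to `device` without blocking the compute stream."""
    cpu_tensors = safetensors.torch.load_file(path, device="cpu")
    gpu_tensors = {}
    with torch.cuda.stream(copy_stream):
        for key, tensor in cpu_tensors.items():
            pinned = tensor.pin_memory()  # page-locked memory is required for a truly async H2D copy
            gpu_tensors[key] = pinned.to(device, non_blocking=True)  # copy is queued on `copy_stream`
    return gpu_tensors


# Usage sketch: start the copy, keep computing on the default stream, then wait before consuming.
# copy_stream = torch.cuda.Stream()
# weights = load_group_async("group_0.safetensors", torch.device("cuda"), copy_stream)
# torch.cuda.current_stream().wait_stream(copy_stream)
```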
@@ -169,6 +215,26 @@ class ModuleGroup:
    @torch.compiler.disable()
    def offload_(self):
        r"""Offloads the group of modules to the offload_device."""
        if self.offload_to_disk_path:
            # TODO: we can potentially optimize this code path by checking if _all_ the desired
            # safetensors files exist on disk and, if so, skip this step entirely, reducing IO
            # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
            # we perform a write.
            # Check if the file has been saved in this session or if it already exists on disk.
            if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
                os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
                tensors_to_save = {
                    key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
                }
                safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)

            # The group is now considered offloaded to disk for the rest of the session.
            self._is_offloaded_to_disk = True

            # Free up the RAM that is still holding the tensor data.
            for tensor_obj in self.tensor_to_key.keys():
                tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
            return

        torch_accelerator_module = (
            getattr(torch, torch.accelerator.current_accelerator().type)
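Note the final loop in the disk branch: once the weights are persisted, each tensor's `.data` is swapped for an uninitialized placeholder so the copy that was resident in RAM can be reclaimed. A toy illustration of that save-then-release idea (the module, path, and helper name are made up):

```py
import os

import safetensors.torch
import torch


def spill_to_disk(module: torch.nn.Module, path: str) -> None:
    """Persist a module's weights to a safetensors file, then drop the in-RAM copies."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    state = {name: param.data.to("cpu") for name, param in module.named_parameters()}
    safetensors.torch.save_file(state, path)
    for param in module.parameters():
        # Keep a placeholder with the right shape/dtype so the parameter object stays valid;
        # the storage that actually held the values can now be freed.
        param.data = torch.empty_like(param.data, device="cpu")


linear = torch.nn.Linear(4, 4)
spill_to_disk(linear, "/tmp/offload/linear.safetensors")
```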
@@ -205,11 +271,7 @@ class GroupOffloadingHook(ModelHook):

    _is_stateful = False

    def __init__(
        self,
        group: ModuleGroup,
        next_group: Optional[ModuleGroup] = None,
    ) -> None:
    def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
        self.group = group
        self.next_group = next_group
@@ -363,6 +425,7 @@ def apply_group_offloading(
    use_stream: bool = False,
    record_stream: bool = False,
    low_cpu_mem_usage: bool = False,
    offload_to_disk_path: Optional[str] = None,
) -> None:
    r"""
    Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -401,6 +464,9 @@ def apply_group_offloading(
        offload_type (`str`, defaults to "block_level"):
            The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
            "block_level".
        offload_to_disk_path (`str`, *optional*, defaults to `None`):
            The path to the directory where parameters will be offloaded. Setting this option can be useful in
            limited-RAM environments where a reasonable speed-memory trade-off is acceptable.
        num_blocks_per_group (`int`, *optional*):
            The number of blocks per group when using offload_type="block_level". This is required when using
            offload_type="block_level".
@@ -458,6 +524,7 @@ def apply_group_offloading(
            num_blocks_per_group=num_blocks_per_group,
            offload_device=offload_device,
            onload_device=onload_device,
            offload_to_disk_path=offload_to_disk_path,
            non_blocking=non_blocking,
            stream=stream,
            record_stream=record_stream,
@@ -468,6 +535,7 @@ def apply_group_offloading(
            module=module,
            offload_device=offload_device,
            onload_device=onload_device,
            offload_to_disk_path=offload_to_disk_path,
            non_blocking=non_blocking,
            stream=stream,
            record_stream=record_stream,
@@ -486,6 +554,7 @@ def _apply_group_offloading_block_level(
    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
    record_stream: Optional[bool] = False,
    low_cpu_mem_usage: bool = False,
    offload_to_disk_path: Optional[str] = None,
) -> None:
    r"""
    This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -496,6 +565,9 @@ def _apply_group_offloading_block_level(
            The module to which group offloading is applied.
        offload_device (`torch.device`):
            The device to which the group of modules is offloaded. This should typically be the CPU.
        offload_to_disk_path (`str`, *optional*, defaults to `None`):
            The path to the directory where parameters will be offloaded. Setting this option can be useful in
            limited-RAM environments where a reasonable speed-memory trade-off is acceptable.
        onload_device (`torch.device`):
            The device to which the group of modules is onloaded.
        non_blocking (`bool`):
@@ -535,6 +607,7 @@ def _apply_group_offloading_block_level(
            modules=current_modules,
            offload_device=offload_device,
            onload_device=onload_device,
            offload_to_disk_path=offload_to_disk_path,
            offload_leader=current_modules[-1],
            onload_leader=current_modules[0],
            non_blocking=non_blocking,
@@ -567,6 +640,7 @@ def _apply_group_offloading_block_level(
        modules=unmatched_modules,
        offload_device=offload_device,
        onload_device=onload_device,
        offload_to_disk_path=offload_to_disk_path,
        offload_leader=module,
        onload_leader=module,
        parameters=parameters,
@@ -590,6 +664,7 @@ def _apply_group_offloading_leaf_level(
    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
    record_stream: Optional[bool] = False,
    low_cpu_mem_usage: bool = False,
    offload_to_disk_path: Optional[str] = None,
) -> None:
    r"""
    This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -604,6 +679,9 @@ def _apply_group_offloading_leaf_level(
            The device to which the group of modules is offloaded. This should typically be the CPU.
        onload_device (`torch.device`):
            The device to which the group of modules is onloaded.
        offload_to_disk_path (`str`, *optional*, defaults to `None`):
            The path to the directory where parameters will be offloaded. Setting this option can be useful in
            limited-RAM environments where a reasonable speed-memory trade-off is acceptable.
        non_blocking (`bool`):
            If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
            and data transfer.
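For the leaf-level variant, a hedged usage sketch that combines `use_stream` with the asynchronous transfers described above (the model choice is illustrative):

```py
import torch
from diffusers import AutoencoderKL
from diffusers.hooks import apply_group_offloading

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)

# Leaf-level offloading keeps only a small working set on the GPU; with `use_stream=True`,
# the next leaf's weights are prefetched on a side stream while the current one runs.
apply_group_offloading(
    vae,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
)
```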
@@ -629,6 +707,7 @@ def _apply_group_offloading_leaf_level(
            modules=[submodule],
            offload_device=offload_device,
            onload_device=onload_device,
            offload_to_disk_path=offload_to_disk_path,
            offload_leader=submodule,
            onload_leader=submodule,
            non_blocking=non_blocking,
@@ -675,6 +754,7 @@ def _apply_group_offloading_leaf_level(
            onload_device=onload_device,
            offload_leader=parent_module,
            onload_leader=parent_module,
            offload_to_disk_path=offload_to_disk_path,
            parameters=parameters,
            buffers=buffers,
            non_blocking=non_blocking,
@@ -693,6 +773,7 @@ def _apply_group_offloading_leaf_level(
            modules=[],
            offload_device=offload_device,
            onload_device=onload_device,
            offload_to_disk_path=offload_to_disk_path,
            offload_leader=module,
            onload_leader=module,
            parameters=None,
@@ -548,6 +548,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        use_stream: bool = False,
        record_stream: bool = False,
        low_cpu_mem_usage=False,
        offload_to_disk_path: Optional[str] = None,
    ) -> None:
        r"""
        Activates group offloading for the current model.
@@ -588,15 +589,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                f"open an issue at https://github.com/huggingface/diffusers/issues."
            )
        apply_group_offloading(
            self,
            onload_device,
            offload_device,
            offload_type,
            num_blocks_per_group,
            non_blocking,
            use_stream,
            record_stream,
            module=self,
            onload_device=onload_device,
            offload_device=offload_device,
            offload_type=offload_type,
            num_blocks_per_group=num_blocks_per_group,
            non_blocking=non_blocking,
            use_stream=use_stream,
            record_stream=record_stream,
            low_cpu_mem_usage=low_cpu_mem_usage,
            offload_to_disk_path=offload_to_disk_path,
        )

    def save_pretrained(
@@ -15,6 +15,7 @@

import copy
import gc
import glob
import inspect
import json
import os
@@ -1693,6 +1694,35 @@ class ModelTesterMixin:
        model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
        _ = model(**inputs_dict)[0]

    @parameterized.expand([(False, "block_level"), (True, "leaf_level")])
    @require_torch_accelerator
    @torch.no_grad()
    def test_group_offloading_with_disk(self, record_stream, offload_type):
        torch.manual_seed(0)
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)

        if not getattr(model, "_supports_group_offloading", True):
            return

        torch.manual_seed(0)
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.eval()
        additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
        with tempfile.TemporaryDirectory() as tmpdir:
            model.enable_group_offload(
                torch_device,
                offload_type=offload_type,
                offload_to_disk_path=tmpdir,
                use_stream=True,
                record_stream=record_stream,
                **additional_kwargs,
            )
            has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
            assert has_safetensors, "No safetensors found in the directory."
            _ = model(**inputs_dict)[0]

    def test_auto_model(self, expected_max_diff=5e-5):
        if self.forward_requires_fresh_args:
            model = self.model_class(**self.init_dict)
@@ -30,6 +30,7 @@ from diffusers.utils.testing_utils import (
    enable_full_determinism,
    torch_device,
)
from diffusers.utils.torch_utils import randn_tensor

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -151,7 +152,7 @@ class SanaControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        control_image = torch.randn(1, 3, 32, 32, generator=generator)
        control_image = randn_tensor((1, 3, 32, 32), generator=generator, device=device)
        inputs = {
            "prompt": "",
            "negative_prompt": "",
@@ -24,6 +24,7 @@ from diffusers.utils.testing_utils import (
    enable_full_determinism,
    torch_device,
)
from diffusers.utils.torch_utils import randn_tensor

from ..pipeline_params import (
    IMAGE_TO_IMAGE_IMAGE_PARAMS,
@@ -137,7 +138,7 @@ class SanaSprintImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        image = torch.randn(1, 3, 32, 32, generator=generator)
        image = randn_tensor((1, 3, 32, 32), generator=generator, device=device)
        inputs = {
            "prompt": "",
            "image": image,
@@ -199,3 +199,15 @@ class WanVACEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    @unittest.skip("Batching is not yet supported with this pipeline")
    def test_inference_batch_single_identical(self):
        return super().test_inference_batch_single_identical()

    @unittest.skip(
        "AutoencoderKLWan encoded latents are always in FP32. This test is not designed to handle mixed dtype inputs"
    )
    def test_float16_inference(self):
        pass

    @unittest.skip(
        "AutoencoderKLWan encoded latents are always in FP32. This test is not designed to handle mixed dtype inputs"
    )
    def test_save_load_float16(self):
        pass