
Merge branch 'main' into chroma-docs

Authored by Sayak Paul on 2025-06-19 20:19:14 +05:30, committed by GitHub.
9 changed files with 150 additions and 16 deletions

View File

@@ -302,6 +302,13 @@ compute-bound, [group-offloading](#group-offloading) tends to be better. Group o
</Tip>
### Offloading to disk
Group offloading can consume significant system RAM depending on the model size. In environments with limited RAM,
it can be useful to offload to disk instead. You can do this by setting the `offload_to_disk_path`
argument in either [`~ModelMixin.enable_group_offload`] or [`~hooks.apply_group_offloading`]. Refer to the [PR description](https://github.com/huggingface/diffusers/pull/11682#issue-3129365363) and
[this comment](https://github.com/huggingface/diffusers/pull/11682#issuecomment-2955715126) for the expected speed-memory trade-offs with this option enabled.
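For example, the following sketch enables disk offloading on a single model (the checkpoint and path are illustrative; any model inheriting from [`ModelMixin`] that supports group offloading works):

```py
import torch
from diffusers import CogVideoXTransformer3DModel

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    offload_to_disk_path="path/to/disk",  # parameter groups are written here as .safetensors files
)
```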
## Layerwise casting
Layerwise casting stores weights in a smaller data format (for example, `torch.float8_e4m3fn` and `torch.float8_e5m2`) to use less memory and upcasts those weights to a higher precision like `torch.float16` or `torch.bfloat16` for computation. Certain layers (normalization- and modulation-related weights) are skipped because storing them in fp8 can degrade generation quality.
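For example, a minimal sketch with the same illustrative checkpoint as above:

```py
import torch
from diffusers import CogVideoXTransformer3DModel

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)
# Weights are stored in fp8 and upcast to bfloat16 on the fly for each forward pass.
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
```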

View File

@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager, nullcontext
from typing import Dict, List, Optional, Set, Tuple, Union
import safetensors.torch
import torch
from ..utils import get_logger, is_accelerate_available
@@ -59,6 +61,7 @@ class ModuleGroup:
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
offload_to_disk_path: Optional[str] = None,
) -> None:
self.modules = modules
self.offload_device = offload_device
@@ -72,7 +75,26 @@ class ModuleGroup:
self.record_stream = record_stream
self.onload_self = onload_self
self.low_cpu_mem_usage = low_cpu_mem_usage
- self.cpu_param_dict = self._init_cpu_param_dict()
+ self.offload_to_disk_path = offload_to_disk_path
+ self._is_offloaded_to_disk = False
+ if self.offload_to_disk_path:
+     self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
+     all_tensors = []
+     for module in self.modules:
+         all_tensors.extend(list(module.parameters()))
+         all_tensors.extend(list(module.buffers()))
+     all_tensors.extend(self.parameters)
+     all_tensors.extend(self.buffers)
+     all_tensors = list(dict.fromkeys(all_tensors))  # Remove duplicates
+     self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
+     self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
+     self.cpu_param_dict = {}
+ else:
+     self.cpu_param_dict = self._init_cpu_param_dict()
if self.stream is None and self.record_stream:
raise ValueError("`record_stream` cannot be True when `stream` is None.")
@@ -124,6 +146,30 @@ class ModuleGroup:
context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
if self.offload_to_disk_path:
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
with context:
if self.stream is not None:
# Load to CPU, pin, and async copy to device for overlapping transfer and compute
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
for key, tensor_obj in self.key_to_tensor.items():
pinned_tensor = loaded_cpu_tensors[key].pin_memory()
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
tensor_obj.data.record_stream(current_stream)
else:
# Load directly to the target device (synchronous)
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
return
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
@@ -169,6 +215,26 @@ class ModuleGroup:
@torch.compiler.disable()
def offload_(self):
r"""Offloads the group of modules to the offload_device."""
if self.offload_to_disk_path:
# TODO: we can potentially optimize this code path by checking whether _all_ the desired
# safetensors files already exist on disk and, if so, skipping this step entirely to reduce IO
# overhead. Currently, we only check whether the given `safetensors_file_path` exists and, if it
# does not, we perform a write.
# Check if the file has been saved in this session or if it already exists on disk.
if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
tensors_to_save = {
key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
}
safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
# The group is now considered offloaded to disk for the rest of the session.
self._is_offloaded_to_disk = True
# We do this to free up the RAM that is still holding the tensor data.
for tensor_obj in self.tensor_to_key.keys():
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
return
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
@@ -205,11 +271,7 @@ class GroupOffloadingHook(ModelHook):
_is_stateful = False
- def __init__(
-     self,
-     group: ModuleGroup,
-     next_group: Optional[ModuleGroup] = None,
- ) -> None:
+ def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
self.group = group
self.next_group = next_group
@@ -363,6 +425,7 @@ def apply_group_offloading(
use_stream: bool = False,
record_stream: bool = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
) -> None:
r"""
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -401,6 +464,9 @@ def apply_group_offloading(
offload_type (`str`, defaults to "block_level"):
The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
"block_level".
offload_to_disk_path (`str`, *optional*, defaults to `None`):
The path to the directory where parameters will be offloaded. Setting this option can be useful in
environments with limited RAM where a reasonable speed-memory trade-off is desired.
num_blocks_per_group (`int`, *optional*):
The number of blocks per group when using offload_type="block_level". This is required when using
offload_type="block_level".
@@ -458,6 +524,7 @@ def apply_group_offloading(
num_blocks_per_group=num_blocks_per_group,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
@@ -468,6 +535,7 @@ def apply_group_offloading(
module=module,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
@@ -486,6 +554,7 @@ def _apply_group_offloading_block_level(
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
) -> None:
r"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -496,6 +565,9 @@ def _apply_group_offloading_block_level(
The module to which group offloading is applied.
offload_device (`torch.device`):
The device to which the group of modules are offloaded. This should typically be the CPU.
offload_to_disk_path (`str`, *optional*, defaults to `None`):
The path to the directory where parameters will be offloaded. Setting this option can be useful in
environments with limited RAM where a reasonable speed-memory trade-off is desired.
onload_device (`torch.device`):
The device to which the group of modules are onloaded.
non_blocking (`bool`):
@@ -535,6 +607,7 @@ def _apply_group_offloading_block_level(
modules=current_modules,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_leader=current_modules[-1],
onload_leader=current_modules[0],
non_blocking=non_blocking,
@@ -567,6 +640,7 @@ def _apply_group_offloading_block_level(
modules=unmatched_modules,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=parameters,
@@ -590,6 +664,7 @@ def _apply_group_offloading_leaf_level(
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
) -> None:
r"""
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -604,6 +679,9 @@ def _apply_group_offloading_leaf_level(
The device to which the group of modules are offloaded. This should typically be the CPU.
onload_device (`torch.device`):
The device to which the group of modules are onloaded.
offload_to_disk_path (`str`, *optional*, defaults to `None`):
The path to the directory where parameters will be offloaded. Setting this option can be useful in
environments with limited RAM where a reasonable speed-memory trade-off is desired.
non_blocking (`bool`):
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
and data transfer.
@@ -629,6 +707,7 @@ def _apply_group_offloading_leaf_level(
modules=[submodule],
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_leader=submodule,
onload_leader=submodule,
non_blocking=non_blocking,
@@ -675,6 +754,7 @@ def _apply_group_offloading_leaf_level(
onload_device=onload_device,
offload_leader=parent_module,
onload_leader=parent_module,
offload_to_disk_path=offload_to_disk_path,
parameters=parameters,
buffers=buffers,
non_blocking=non_blocking,
@@ -693,6 +773,7 @@ def _apply_group_offloading_leaf_level(
modules=[],
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=None,
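For reference, the disk code path added above reduces to a `safetensors` save/load round-trip keyed by stable per-tensor names. A standalone sketch of that pattern (the directory, keys, and tensors are made up for illustration and do not mirror the hook's actual bookkeeping):

```py
import os

import safetensors.torch
import torch

offload_dir = "offload"  # illustrative directory
os.makedirs(offload_dir, exist_ok=True)
file_path = os.path.join(offload_dir, "group_example.safetensors")

# Offload: map each tensor to a stable string key and write the CPU copies to disk.
tensors = {"tensor_0": torch.randn(4, 4), "tensor_1": torch.randn(8)}
safetensors.torch.save_file(tensors, file_path)

# Onload (synchronous): read the tensors back from disk.
loaded = safetensors.torch.load_file(file_path, device="cpu")

# Onload (streamed): pin the CPU copies and copy asynchronously to overlap transfer with compute.
if torch.cuda.is_available():
    on_device = {k: v.pin_memory().to("cuda", non_blocking=True) for k, v in loaded.items()}
```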

View File

@@ -548,6 +548,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
use_stream: bool = False,
record_stream: bool = False,
low_cpu_mem_usage=False,
offload_to_disk_path: Optional[str] = None,
) -> None:
r"""
Activates group offloading for the current model.
@@ -588,15 +589,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
f"open an issue at https://github.com/huggingface/diffusers/issues."
)
apply_group_offloading(
- self,
- onload_device,
- offload_device,
- offload_type,
- num_blocks_per_group,
- non_blocking,
- use_stream,
- record_stream,
+ module=self,
+ onload_device=onload_device,
+ offload_device=offload_device,
+ offload_type=offload_type,
+ num_blocks_per_group=num_blocks_per_group,
+ non_blocking=non_blocking,
+ use_stream=use_stream,
+ record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
+ offload_to_disk_path=offload_to_disk_path,
)
def save_pretrained(

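The keyword-argument call above forwards every option to [`~hooks.apply_group_offloading`], which can also be used directly; a sketch with the same illustrative checkpoint and path as before:

```py
import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks import apply_group_offloading

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)
apply_group_offloading(
    module=transformer,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,  # required for "block_level"
    offload_to_disk_path="path/to/disk",  # optional; omit to keep offloaded tensors in system RAM
)
```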
View File

@@ -15,6 +15,7 @@
import copy
import gc
import glob
import inspect
import json
import os
@@ -1693,6 +1694,35 @@ class ModelTesterMixin:
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
_ = model(**inputs_dict)[0]
@parameterized.expand([(False, "block_level"), (True, "leaf_level")])
@require_torch_accelerator
@torch.no_grad()
def test_group_offloading_with_disk(self, record_stream, offload_type):
torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
if not getattr(model, "_supports_group_offloading", True):
return
torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.eval()
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
with tempfile.TemporaryDirectory() as tmpdir:
model.enable_group_offload(
torch_device,
offload_type=offload_type,
offload_to_disk_path=tmpdir,
use_stream=True,
record_stream=record_stream,
**additional_kwargs,
)
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
assert has_safetensors, "No safetensors found in the directory."
_ = model(**inputs_dict)[0]
def test_auto_model(self, expected_max_diff=5e-5):
if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)

View File

@@ -30,6 +30,7 @@ from diffusers.utils.testing_utils import (
enable_full_determinism,
torch_device,
)
from diffusers.utils.torch_utils import randn_tensor
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -151,7 +152,7 @@ class SanaControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
else:
generator = torch.Generator(device=device).manual_seed(seed)
- control_image = torch.randn(1, 3, 32, 32, generator=generator)
+ control_image = randn_tensor((1, 3, 32, 32), generator=generator, device=device)
inputs = {
"prompt": "",
"negative_prompt": "",

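For context, `randn_tensor` draws the noise with the provided generator (on the CPU when the generator lives there) and only then places it on the requested device, keeping the dummy inputs reproducible across devices; a small sketch:

```py
import torch
from diffusers.utils.torch_utils import randn_tensor

generator = torch.Generator(device="cpu").manual_seed(0)
# The same values are produced whether the target device is the CPU or an accelerator.
control_image = randn_tensor((1, 3, 32, 32), generator=generator, device=torch.device("cpu"))
```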
View File

@@ -24,6 +24,7 @@ from diffusers.utils.testing_utils import (
enable_full_determinism,
torch_device,
)
from diffusers.utils.torch_utils import randn_tensor
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
@@ -137,7 +138,7 @@ class SanaSprintImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
- image = torch.randn(1, 3, 32, 32, generator=generator)
+ image = randn_tensor((1, 3, 32, 32), generator=generator, device=device)
inputs = {
"prompt": "",
"image": image,

View File

@@ -199,3 +199,15 @@ class WanVACEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@unittest.skip("Batching is not yet supported with this pipeline")
def test_inference_batch_single_identical(self):
return super().test_inference_batch_single_identical()
@unittest.skip(
"AutoencoderKLWan encoded latents are always in FP32. This test is not designed to handle mixed dtype inputs"
)
def test_float16_inference(self):
pass
@unittest.skip(
"AutoencoderKLWan encoded latents are always in FP32. This test is not designed to handle mixed dtype inputs"
)
def test_save_load_float16(self):
pass