1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Core] introduce _no_split_modules to ModelMixin (#6396)

* introduce _no_split_modules.

* unnecessary spaces.

* remove unnecessary kwargs and style

* fix: accelerate imports.

* change to _determine_device_map

* add the blocks that have residual connections.

* add: CrossAttnUpBlock2D

* add: testin

* style

* line-spaces

* quality

* add disk offload test without safetensors.

* checking disk offloading percentages.

* change model split

* add: utility for checking multi-gpu requirement.

* model parallelism test

* splits.

* splits.

* splits

* splits.

* splits.

* splits.

* offload folder to test_disk_offload_with_safetensors

* add _no_split_modules

* fix-copies
This commit is contained in:
Sayak Paul
2024-04-30 08:46:51 +05:30
committed by GitHub
parent b02e2113ff
commit 3fd31eef51
7 changed files with 221 additions and 5 deletions

View File

@@ -65,6 +65,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
@register_to_config
def __init__(

View File

@@ -57,7 +57,8 @@ else:
if is_accelerate_available():
import accelerate
from accelerate.utils import set_module_tensor_to_device
from accelerate import infer_auto_device_map
from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device
from accelerate.utils.versions import is_torch_version
@@ -99,6 +100,29 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
return first_tuple[1].dtype
# Adapted from `transformers` (see modeling_utils.py)
def _determine_device_map(model: "ModelMixin", device_map, max_memory, torch_dtype):
if isinstance(device_map, str):
no_split_modules = model._get_no_split_modules(device_map)
device_map_kwargs = {"no_split_module_classes": no_split_modules}
if device_map != "sequential":
max_memory = get_balanced_memory(
model,
dtype=torch_dtype,
low_zero=(device_map == "balanced_low_0"),
max_memory=max_memory,
**device_map_kwargs,
)
else:
max_memory = get_max_memory(max_memory)
device_map_kwargs["max_memory"] = max_memory
device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)
return device_map
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
@@ -201,6 +225,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
_supports_gradient_checkpointing = False
_keys_to_ignore_on_load_unexpected = None
_no_split_modules = None
def __init__(self):
super().__init__()
@@ -560,6 +585,36 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
)
# change device_map into a map if we passed an int, a str or a torch.device
if isinstance(device_map, torch.device):
device_map = {"": device_map}
elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
try:
device_map = {"": torch.device(device_map)}
except RuntimeError:
raise ValueError(
"When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
)
elif isinstance(device_map, int):
if device_map < 0:
raise ValueError(
"You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
)
else:
device_map = {"": device_map}
if device_map is not None:
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
elif not low_cpu_mem_usage:
raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
if low_cpu_mem_usage:
if device_map is not None and not is_torch_version(">=", "1.10"):
# The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")
# Load config if we don't provide a configuration
config_path = pretrained_model_name_or_path
@@ -582,10 +637,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
token=token,
revision=revision,
subfolder=subfolder,
device_map=device_map,
max_memory=max_memory,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
user_agent=user_agent,
**kwargs,
)
@@ -690,6 +741,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else: # else let accelerate handle loading and dispatching.
# Load weights and dispatch according to the device_map
# by default the device_map is None and the weights are loaded on the CPU
device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
try:
accelerate.load_checkpoint_and_dispatch(
model,
@@ -881,6 +933,36 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
# Adapted from `transformers` modeling_utils.py
def _get_no_split_modules(self, device_map: str):
"""
Get the modules of the model that should not be spit when using device_map. We iterate through the modules to
get the underlying `_no_split_modules`.
Args:
device_map (`str`):
The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
Returns:
`List[str]`: List of modules that should not be split
"""
_no_split_modules = set()
modules_to_check = [self]
while len(modules_to_check) > 0:
module = modules_to_check.pop(-1)
# if the module does not appear in _no_split_modules, we also check the children
if module.__class__.__name__ not in _no_split_modules:
if isinstance(module, ModelMixin):
if module._no_split_modules is None:
raise ValueError(
f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
"class needs to implement the `_no_split_modules` attribute."
)
else:
_no_split_modules = _no_split_modules | set(module._no_split_modules)
modules_to_check += list(module.children())
return list(_no_split_modules)
@property
def device(self) -> torch.device:
"""

View File

@@ -72,6 +72,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock"]
@register_to_config
def __init__(

View File

@@ -161,6 +161,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
@register_to_config
def __init__(

View File

@@ -363,6 +363,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlockFlat", "CrossAttnUpBlockFlat"]
@register_to_config
def __init__(

View File

@@ -24,6 +24,7 @@ from typing import Dict, List, Tuple
import numpy as np
import requests_mock
import torch
from accelerate.utils import compute_module_sizes
from huggingface_hub import ModelCard, delete_repo
from huggingface_hub.utils import is_jinja_available
from requests.exceptions import HTTPError
@@ -39,6 +40,7 @@ from diffusers.utils.testing_utils import (
require_torch_2,
require_torch_accelerator_with_training,
require_torch_gpu,
require_torch_multi_gpu,
run_test_in_subprocess,
torch_device,
)
@@ -200,6 +202,21 @@ class ModelTesterMixin:
main_input_name = None # overwrite in model specific tester class
base_precision = 1e-3
forward_requires_fresh_args = False
model_split_percents = [0.5, 0.7, 0.9]
def check_device_map_is_respected(self, model, device_map):
for param_name, param in model.named_parameters():
# Find device in device_map
while len(param_name) > 0 and param_name not in device_map:
param_name = ".".join(param_name.split(".")[:-1])
if param_name not in device_map:
raise ValueError("device map is incomplete, it does not contain any device for `param_name`.")
param_device = device_map[param_name]
if param_device in ["cpu", "disk"]:
self.assertEqual(param.device, torch.device("meta"))
else:
self.assertEqual(param.device, torch.device(param_device))
def test_from_save_pretrained(self, expected_max_diff=5e-5):
if self.forward_requires_fresh_args:
@@ -670,6 +687,117 @@ class ModelTesterMixin:
" from `_deprecated_kwargs = [<deprecated_argument>]`"
)
@require_torch_gpu
def test_cpu_offload(self):
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
# We test several splits of sizes to make sure it works.
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)
for max_size in max_gpu_sizes:
max_memory = {0: max_size, "cpu": model_size * 2}
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
@require_torch_gpu
def test_disk_offload_without_safetensors(self):
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
with self.assertRaises(ValueError):
max_size = int(self.model_split_percents[1] * model_size)
max_memory = {0: max_size, "cpu": max_size}
# This errors out because it's missing an offload folder
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
max_size = int(self.model_split_percents[1] * model_size)
max_memory = {0: max_size, "cpu": max_size}
new_model = self.model_class.from_pretrained(
tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
)
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
@require_torch_gpu
def test_disk_offload_with_safetensors(self):
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)
max_size = int(self.model_split_percents[1] * model_size)
max_memory = {0: max_size, "cpu": max_size}
new_model = self.model_class.from_pretrained(
tmp_dir, device_map="auto", offload_folder=tmp_dir, max_memory=max_memory
)
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
@require_torch_multi_gpu
def test_model_parallelism(self):
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
# We test several splits of sizes to make sure it works.
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)
for max_size in max_gpu_sizes:
max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
@is_staging_test
class ModelPushToHubTester(unittest.TestCase):

View File

@@ -300,6 +300,8 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel
main_input_name = "sample"
# We override the items here because the unet under consideration is small.
model_split_percents = [0.5, 0.3, 0.4]
@property
def dummy_input(self):