Mirror of https://github.com/huggingface/diffusers.git
[Core] introduce _no_split_modules to ModelMixin (#6396)
* introduce _no_split_modules.
* unnecessary spaces.
* remove unnecessary kwargs and style
* fix: accelerate imports.
* change to _determine_device_map
* add the blocks that have residual connections.
* add: CrossAttnUpBlock2D
* add: testing
* style
* line-spaces
* quality
* add disk offload test without safetensors.
* checking disk offloading percentages.
* change model split
* add: utility for checking multi-gpu requirement.
* model parallelism test
* splits.
* splits.
* splits
* splits.
* splits.
* splits.
* offload folder to test_disk_offload_with_safetensors
* add _no_split_modules
* fix-copies
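In practice, the change means a model class that defines `_no_split_modules` can be loaded with an automatic device map. A minimal usage sketch follows; the checkpoint id and memory budget are illustrative assumptions, not part of this diff:

```python
import torch
from diffusers import UNet2DConditionModel

# Hypothetical checkpoint and memory budget, for illustration only.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    device_map="auto",                       # let accelerate place submodules
    max_memory={0: "4GiB", "cpu": "16GiB"},  # optional per-device cap
    torch_dtype=torch.float16,
)
# Which submodule landed on which device:
print(unet.hf_device_map)
```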
@@ -65,6 +65,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]

    @register_to_config
    def __init__(
@@ -57,7 +57,8 @@ else:

if is_accelerate_available():
    import accelerate
    from accelerate.utils import set_module_tensor_to_device
    from accelerate import infer_auto_device_map
    from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device
    from accelerate.utils.versions import is_torch_version
@@ -99,6 +100,29 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
    return first_tuple[1].dtype


# Adapted from `transformers` (see modeling_utils.py)
def _determine_device_map(model: "ModelMixin", device_map, max_memory, torch_dtype):
    if isinstance(device_map, str):
        no_split_modules = model._get_no_split_modules(device_map)
        device_map_kwargs = {"no_split_module_classes": no_split_modules}

        if device_map != "sequential":
            max_memory = get_balanced_memory(
                model,
                dtype=torch_dtype,
                low_zero=(device_map == "balanced_low_0"),
                max_memory=max_memory,
                **device_map_kwargs,
            )
        else:
            max_memory = get_max_memory(max_memory)

        device_map_kwargs["max_memory"] = max_memory
        device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)

    return device_map


def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
    """
    Reads a checkpoint file, returning properly formatted errors if they arise.
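For intuition, the sketch below shows roughly what the string policies resolve to: accelerate balances (or, for "sequential", simply fills) the available memory and assigns whole submodules to devices, never splitting the classes listed in `_no_split_modules`. The model instantiation and memory figures are illustrative assumptions, and the sketch relies on the `_get_no_split_modules` helper added later in this diff:

```python
import torch
from accelerate import infer_auto_device_map
from diffusers import UNet2DConditionModel

model = UNet2DConditionModel()  # default config, used only as an example
device_map = infer_auto_device_map(
    model,
    max_memory={0: "2GiB", "cpu": "8GiB"},                     # hypothetical budget
    no_split_module_classes=model._get_no_split_modules("auto"),
    dtype=torch.float16,
)
# Keys are module names, values are devices, e.g.
# {"conv_in": 0, "down_blocks": 0, "mid_block": "cpu", ...}
print(device_map)
```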
@@ -201,6 +225,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
    _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
    _supports_gradient_checkpointing = False
    _keys_to_ignore_on_load_unexpected = None
    _no_split_modules = None

    def __init__(self):
        super().__init__()
@@ -560,6 +585,36 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
            )

        # change device_map into a map if we passed an int, a str or a torch.device
        if isinstance(device_map, torch.device):
            device_map = {"": device_map}
        elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
            try:
                device_map = {"": torch.device(device_map)}
            except RuntimeError:
                raise ValueError(
                    "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
                    f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
                )
        elif isinstance(device_map, int):
            if device_map < 0:
                raise ValueError(
                    "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
                )
            else:
                device_map = {"": device_map}

        if device_map is not None:
            if low_cpu_mem_usage is None:
                low_cpu_mem_usage = True
            elif not low_cpu_mem_usage:
                raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")

        if low_cpu_mem_usage:
            if device_map is not None and not is_torch_version(">=", "1.10"):
                # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
                raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")

        # Load config if we don't provide a configuration
        config_path = pretrained_model_name_or_path
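In short, several shorthand forms of `device_map` are normalized to a single-entry dict, while the policy strings are left for accelerate to resolve. A few equivalent calls, sketched with a placeholder checkpoint path:

```python
import torch
from diffusers import UNet2DConditionModel

ckpt = "path/to/local/unet"  # hypothetical checkpoint directory

# Pin the whole model to one device (each normalized to {"": <device>}):
unet = UNet2DConditionModel.from_pretrained(ckpt, device_map="cuda:0")              # device string
unet = UNet2DConditionModel.from_pretrained(ckpt, device_map=0)                     # non-negative GPU index
unet = UNet2DConditionModel.from_pretrained(ckpt, device_map=torch.device("cuda"))  # torch.device

# Or hand placement over to accelerate:
unet = UNet2DConditionModel.from_pretrained(ckpt, device_map="auto")
```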
@@ -582,10 +637,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            token=token,
            revision=revision,
            subfolder=subfolder,
            device_map=device_map,
            max_memory=max_memory,
            offload_folder=offload_folder,
            offload_state_dict=offload_state_dict,
            user_agent=user_agent,
            **kwargs,
        )
@@ -690,6 +741,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            else:  # else let accelerate handle loading and dispatching.
                # Load weights and dispatch according to the device_map
                # by default the device_map is None and the weights are loaded on the CPU
                device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
                try:
                    accelerate.load_checkpoint_and_dispatch(
                        model,
@@ -881,6 +933,36 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):

        return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs

    # Adapted from `transformers` modeling_utils.py
    def _get_no_split_modules(self, device_map: str):
        """
        Get the modules of the model that should not be split when using device_map. We iterate through the modules to
        get the underlying `_no_split_modules`.

        Args:
            device_map (`str`):
                The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]

        Returns:
            `List[str]`: List of modules that should not be split
        """
        _no_split_modules = set()
        modules_to_check = [self]
        while len(modules_to_check) > 0:
            module = modules_to_check.pop(-1)
            # if the module does not appear in _no_split_modules, we also check the children
            if module.__class__.__name__ not in _no_split_modules:
                if isinstance(module, ModelMixin):
                    if module._no_split_modules is None:
                        raise ValueError(
                            f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
                            "class needs to implement the `_no_split_modules` attribute."
                        )
                    else:
                        _no_split_modules = _no_split_modules | set(module._no_split_modules)
                modules_to_check += list(module.children())
        return list(_no_split_modules)

    @property
    def device(self) -> torch.device:
        """
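A minimal sketch of what this traversal returns for a model that defines the attribute, using the UNet2DConditionModel value added later in this diff (ordering is not guaranteed because a set is used internally):

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel()  # default config, illustration only
print(unet._get_no_split_modules("auto"))
# -> some ordering of ['BasicTransformerBlock', 'ResnetBlock2D', 'CrossAttnUpBlock2D']

# A ModelMixin subclass that leaves `_no_split_modules = None` raises instead:
# "<ClassName> does not support `device_map='auto'` ..."
```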
@@ -72,6 +72,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["BasicTransformerBlock"]

    @register_to_config
    def __init__(
@@ -161,6 +161,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]

    @register_to_config
    def __init__(
@@ -363,6 +363,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["BasicTransformerBlock", "ResnetBlockFlat", "CrossAttnUpBlockFlat"]

    @register_to_config
    def __init__(
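The pattern for opting a model class into `device_map` support is to list the block classes that must stay whole on a single device, typically blocks with residual or skip connections (per the commit notes). A hypothetical custom model, for illustration only; none of these names exist in the library:

```python
import torch
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config


class ResidualBlock(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        # The skip connection is why this block must live on a single device.
        return x + self.proj(x)


class MyTinyModel(ModelMixin, ConfigMixin):
    # Hypothetical example: naming the block class keeps every ResidualBlock
    # intact when accelerate infers a device map.
    _no_split_modules = ["ResidualBlock"]

    @register_to_config
    def __init__(self, dim: int = 32, num_blocks: int = 4):
        super().__init__()
        self.blocks = torch.nn.ModuleList([ResidualBlock(dim) for _ in range(num_blocks)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x
```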
@@ -24,6 +24,7 @@ from typing import Dict, List, Tuple
import numpy as np
import requests_mock
import torch
from accelerate.utils import compute_module_sizes
from huggingface_hub import ModelCard, delete_repo
from huggingface_hub.utils import is_jinja_available
from requests.exceptions import HTTPError
@@ -39,6 +40,7 @@ from diffusers.utils.testing_utils import (
    require_torch_2,
    require_torch_accelerator_with_training,
    require_torch_gpu,
    require_torch_multi_gpu,
    run_test_in_subprocess,
    torch_device,
)
@@ -200,6 +202,21 @@ class ModelTesterMixin:
    main_input_name = None  # overwrite in model specific tester class
    base_precision = 1e-3
    forward_requires_fresh_args = False
    model_split_percents = [0.5, 0.7, 0.9]

    def check_device_map_is_respected(self, model, device_map):
        for param_name, param in model.named_parameters():
            # Find device in device_map
            while len(param_name) > 0 and param_name not in device_map:
                param_name = ".".join(param_name.split(".")[:-1])
            if param_name not in device_map:
                raise ValueError("device map is incomplete, it does not contain any device for `param_name`.")

            param_device = device_map[param_name]
            if param_device in ["cpu", "disk"]:
                self.assertEqual(param.device, torch.device("meta"))
            else:
                self.assertEqual(param.device, torch.device(param_device))

    def test_from_save_pretrained(self, expected_max_diff=5e-5):
        if self.forward_requires_fresh_args:
@@ -670,6 +687,117 @@ class ModelTesterMixin:
                " from `_deprecated_kwargs = [<deprecated_argument>]`"
            )

    @require_torch_gpu
    def test_cpu_offload(self):
        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**config).eval()
        model = model.to(torch_device)

        torch.manual_seed(0)
        base_output = model(**inputs_dict)

        model_size = compute_module_sizes(model)[""]
        # We test several splits of sizes to make sure it works.
        max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.cpu().save_pretrained(tmp_dir)

            for max_size in max_gpu_sizes:
                max_memory = {0: max_size, "cpu": model_size * 2}
                new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
                # Making sure part of the model will actually end up offloaded
                self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})

                self.check_device_map_is_respected(new_model, new_model.hf_device_map)
                torch.manual_seed(0)
                new_output = new_model(**inputs_dict)

                self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

    @require_torch_gpu
    def test_disk_offload_without_safetensors(self):
        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**config).eval()
        model = model.to(torch_device)

        torch.manual_seed(0)
        base_output = model(**inputs_dict)

        model_size = compute_module_sizes(model)[""]
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.cpu().save_pretrained(tmp_dir, safe_serialization=False)

            with self.assertRaises(ValueError):
                max_size = int(self.model_split_percents[1] * model_size)
                max_memory = {0: max_size, "cpu": max_size}
                # This errors out because it's missing an offload folder
                new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)

            max_size = int(self.model_split_percents[1] * model_size)
            max_memory = {0: max_size, "cpu": max_size}
            new_model = self.model_class.from_pretrained(
                tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
            )

            self.check_device_map_is_respected(new_model, new_model.hf_device_map)
            torch.manual_seed(0)
            new_output = new_model(**inputs_dict)

            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

    @require_torch_gpu
    def test_disk_offload_with_safetensors(self):
        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**config).eval()
        model = model.to(torch_device)

        torch.manual_seed(0)
        base_output = model(**inputs_dict)

        model_size = compute_module_sizes(model)[""]
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.cpu().save_pretrained(tmp_dir)

            max_size = int(self.model_split_percents[1] * model_size)
            max_memory = {0: max_size, "cpu": max_size}
            new_model = self.model_class.from_pretrained(
                tmp_dir, device_map="auto", offload_folder=tmp_dir, max_memory=max_memory
            )

            self.check_device_map_is_respected(new_model, new_model.hf_device_map)
            torch.manual_seed(0)
            new_output = new_model(**inputs_dict)

            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

    @require_torch_multi_gpu
    def test_model_parallelism(self):
        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**config).eval()
        model = model.to(torch_device)

        torch.manual_seed(0)
        base_output = model(**inputs_dict)

        model_size = compute_module_sizes(model)[""]
        # We test several splits of sizes to make sure it works.
        max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.cpu().save_pretrained(tmp_dir)

            for max_size in max_gpu_sizes:
                max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
                new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
                # Making sure part of the model will actually end up offloaded
                self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})

                self.check_device_map_is_respected(new_model, new_model.hf_device_map)

                torch.manual_seed(0)
                new_output = new_model(**inputs_dict)

                self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))


@is_staging_test
class ModelPushToHubTester(unittest.TestCase):
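Outside the test harness, the same pattern of sizing the memory budget from the model's own footprint looks roughly like this (a sketch; the saved checkpoint path is a placeholder):

```python
from accelerate.utils import compute_module_sizes
from diffusers import UNet2DConditionModel

model = UNet2DConditionModel()                 # illustration only
model_size = compute_module_sizes(model)[""]   # total parameter/buffer size in bytes

max_memory = {0: int(0.5 * model_size), "cpu": model_size * 2}  # half on GPU 0, rest on CPU
reloaded = UNet2DConditionModel.from_pretrained(
    "path/to/saved/unet",                      # hypothetical local checkpoint
    device_map="auto",
    max_memory=max_memory,
)
print(reloaded.hf_device_map)                  # shows which submodules were offloaded
```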
@@ -300,6 +300,8 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = UNet2DConditionModel
    main_input_name = "sample"
    # We override the items here because the unet under consideration is small.
    model_split_percents = [0.5, 0.3, 0.4]

    @property
    def dummy_input(self):