From 3fd31eef518b73ee592f82435f3d370a716ead4f Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Tue, 30 Apr 2024 08:46:51 +0530
Subject: [PATCH] [Core] introduce _no_split_modules to `ModelMixin` (#6396)

* introduce _no_split_modules.
* unnecessary spaces.
* remove unnecessary kwargs and style
* fix: accelerate imports.
* change to _determine_device_map
* add the blocks that have residual connections.
* add: CrossAttnUpBlock2D
* add: testin
* style
* line-spaces
* quality
* add disk offload test without safetensors.
* checking disk offloading percentages.
* change model split
* add: utility for checking multi-gpu requirement.
* model parallelism test
* splits.
* splits.
* splits
* splits.
* splits.
* splits.
* offload folder to test_disk_offload_with_safetensors
* add _no_split_modules
* fix-copies
---
 .../models/autoencoders/autoencoder_kl.py      |   1 +
 src/diffusers/models/modeling_utils.py         |  92 ++++++++++++-
 .../models/transformers/transformer_2d.py      |   1 +
 .../models/unets/unet_2d_condition.py          |   1 +
 .../versatile_diffusion/modeling_text_unet.py  |   1 +
 tests/models/test_modeling_common.py           | 128 ++++++++++++++++++
 .../unets/test_models_unet_2d_condition.py     |   2 +
 7 files changed, 221 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py
index b286453de4..0b9b9d4d47 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -65,6 +65,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
     """
 
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
 
     @register_to_config
     def __init__(
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index c1fdff8ab3..8d9f2d9e71 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -57,7 +57,8 @@ else:
 
 
 if is_accelerate_available():
     import accelerate
-    from accelerate.utils import set_module_tensor_to_device
+    from accelerate import infer_auto_device_map
+    from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device
     from accelerate.utils.versions import is_torch_version
@@ -99,6 +100,29 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
     return first_tuple[1].dtype
 
 
+# Adapted from `transformers` (see modeling_utils.py)
+def _determine_device_map(model: "ModelMixin", device_map, max_memory, torch_dtype):
+    if isinstance(device_map, str):
+        no_split_modules = model._get_no_split_modules(device_map)
+        device_map_kwargs = {"no_split_module_classes": no_split_modules}
+
+        if device_map != "sequential":
+            max_memory = get_balanced_memory(
+                model,
+                dtype=torch_dtype,
+                low_zero=(device_map == "balanced_low_0"),
+                max_memory=max_memory,
+                **device_map_kwargs,
+            )
+        else:
+            max_memory = get_max_memory(max_memory)
+
+        device_map_kwargs["max_memory"] = max_memory
+        device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)
+
+    return device_map
+
+
 def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
     """
     Reads a checkpoint file, returning properly formatted errors if they arise.
@@ -201,6 +225,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
     _supports_gradient_checkpointing = False
     _keys_to_ignore_on_load_unexpected = None
+    _no_split_modules = None
 
     def __init__(self):
         super().__init__()
@@ -560,6 +585,36 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
             )
 
+        # change device_map into a map if we passed an int, a str or a torch.device
+        if isinstance(device_map, torch.device):
+            device_map = {"": device_map}
+        elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
+            try:
+                device_map = {"": torch.device(device_map)}
+            except RuntimeError:
+                raise ValueError(
+                    "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
+                    f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
+                )
+        elif isinstance(device_map, int):
+            if device_map < 0:
+                raise ValueError(
+                    "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
+                )
+            else:
+                device_map = {"": device_map}
+
+        if device_map is not None:
+            if low_cpu_mem_usage is None:
+                low_cpu_mem_usage = True
+            elif not low_cpu_mem_usage:
+                raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
+
+        if low_cpu_mem_usage:
+            if device_map is not None and not is_torch_version(">=", "1.10"):
+                # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
+                raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")
+
         # Load config if we don't provide a configuration
         config_path = pretrained_model_name_or_path
@@ -582,10 +637,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             token=token,
             revision=revision,
             subfolder=subfolder,
-            device_map=device_map,
-            max_memory=max_memory,
-            offload_folder=offload_folder,
-            offload_state_dict=offload_state_dict,
             user_agent=user_agent,
             **kwargs,
         )
@@ -690,6 +741,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         else:  # else let accelerate handle loading and dispatching.
             # Load weights and dispatch according to the device_map
             # by default the device_map is None and the weights are loaded on the CPU
+            device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
             try:
                 accelerate.load_checkpoint_and_dispatch(
                     model,
@@ -881,6 +933,36 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
 
         return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
 
+    # Adapted from `transformers` modeling_utils.py
+    def _get_no_split_modules(self, device_map: str):
+        """
+        Get the modules of the model that should not be split when using device_map. We iterate through the modules
+        to get the underlying `_no_split_modules`.
+
+        Args:
+            device_map (`str`):
+                The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
+
+        Returns:
+            `List[str]`: List of modules that should not be split
+        """
+        _no_split_modules = set()
+        modules_to_check = [self]
+        while len(modules_to_check) > 0:
+            module = modules_to_check.pop(-1)
+            # if the module does not appear in _no_split_modules, we also check the children
+            if module.__class__.__name__ not in _no_split_modules:
+                if isinstance(module, ModelMixin):
+                    if module._no_split_modules is None:
+                        raise ValueError(
+                            f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
+                            "class needs to implement the `_no_split_modules` attribute."
+                        )
+                    else:
+                        _no_split_modules = _no_split_modules | set(module._no_split_modules)
+                modules_to_check += list(module.children())
+        return list(_no_split_modules)
+
     @property
     def device(self) -> torch.device:
         """
diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py
index 768fceb71a..6a2695b9e4 100644
--- a/src/diffusers/models/transformers/transformer_2d.py
+++ b/src/diffusers/models/transformers/transformer_2d.py
@@ -72,6 +72,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock"]
 
     @register_to_config
     def __init__(
diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py
index 34327e1049..697730b359 100644
--- a/src/diffusers/models/unets/unet_2d_condition.py
+++ b/src/diffusers/models/unets/unet_2d_condition.py
@@ -161,6 +161,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
     """
 
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
 
     @register_to_config
     def __init__(
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
index 3c3bd52669..c84caa1fad 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
@@ -363,6 +363,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock", "ResnetBlockFlat", "CrossAttnUpBlockFlat"]
 
     @register_to_config
     def __init__(
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index f919ba10fb..d8a93d40c8 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -24,6 +24,7 @@ from typing import Dict, List, Tuple
 import numpy as np
 import requests_mock
 import torch
+from accelerate.utils import compute_module_sizes
 from huggingface_hub import ModelCard, delete_repo
 from huggingface_hub.utils import is_jinja_available
 from requests.exceptions import HTTPError
@@ -39,6 +40,7 @@ from diffusers.utils.testing_utils import (
     require_torch_2,
     require_torch_accelerator_with_training,
     require_torch_gpu,
+    require_torch_multi_gpu,
     run_test_in_subprocess,
     torch_device,
 )
@@ -200,6 +202,21 @@ class ModelTesterMixin:
     main_input_name = None  # overwrite in model specific tester class
     base_precision = 1e-3
     forward_requires_fresh_args = False
+    model_split_percents = [0.5, 0.7, 0.9]
+
+    def check_device_map_is_respected(self, model, device_map):
+        for param_name, param in model.named_parameters():
+            # Find device in device_map
+            while len(param_name) > 0 and param_name not in device_map:
+                param_name = ".".join(param_name.split(".")[:-1])
+            if param_name not in device_map:
+                raise ValueError("device map is incomplete, it does not contain any device for `param_name`.")
+
+            param_device = device_map[param_name]
+            if param_device in ["cpu", "disk"]:
+                self.assertEqual(param.device, torch.device("meta"))
+            else:
+                self.assertEqual(param.device, torch.device(param_device))
 
     def test_from_save_pretrained(self, expected_max_diff=5e-5):
         if self.forward_requires_fresh_args:
@@ -670,6 +687,117 @@ class ModelTesterMixin:
             " from `_deprecated_kwargs = []`"
         )
 
+    @require_torch_gpu
+    def test_cpu_offload(self):
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model = model.to(torch_device)
+
+        torch.manual_seed(0)
+        base_output = model(**inputs_dict)
+
+        model_size = compute_module_sizes(model)[""]
+        # We test several splits of sizes to make sure it works.
+        max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.cpu().save_pretrained(tmp_dir)
+
+            for max_size in max_gpu_sizes:
+                max_memory = {0: max_size, "cpu": model_size * 2}
+                new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+                # Making sure part of the model will actually end up offloaded
+                self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})
+
+                self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+                torch.manual_seed(0)
+                new_output = new_model(**inputs_dict)
+
+                self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+
+    @require_torch_gpu
+    def test_disk_offload_without_safetensors(self):
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model = model.to(torch_device)
+
+        torch.manual_seed(0)
+        base_output = model(**inputs_dict)
+
+        model_size = compute_module_sizes(model)[""]
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
+
+            with self.assertRaises(ValueError):
+                max_size = int(self.model_split_percents[1] * model_size)
+                max_memory = {0: max_size, "cpu": max_size}
+                # This errors out because it's missing an offload folder
+                new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+
+            max_size = int(self.model_split_percents[1] * model_size)
+            max_memory = {0: max_size, "cpu": max_size}
+            new_model = self.model_class.from_pretrained(
+                tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
+            )
+
+            self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+            torch.manual_seed(0)
+            new_output = new_model(**inputs_dict)
+
+            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+
+    @require_torch_gpu
+    def test_disk_offload_with_safetensors(self):
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model = model.to(torch_device)
+
+        torch.manual_seed(0)
+        base_output = model(**inputs_dict)
+
+        model_size = compute_module_sizes(model)[""]
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.cpu().save_pretrained(tmp_dir)
+
+            max_size = int(self.model_split_percents[1] * model_size)
+            max_memory = {0: max_size, "cpu": max_size}
+            new_model = self.model_class.from_pretrained(
+                tmp_dir, device_map="auto", offload_folder=tmp_dir, max_memory=max_memory
+            )
+
+            self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+            torch.manual_seed(0)
+            new_output = new_model(**inputs_dict)
+
+            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+
+    @require_torch_multi_gpu
+    def test_model_parallelism(self):
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model = model.to(torch_device)
+
+        torch.manual_seed(0)
+        base_output = model(**inputs_dict)
+
+        model_size = compute_module_sizes(model)[""]
+        # We test several splits of sizes to make sure it works.
+        max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.cpu().save_pretrained(tmp_dir)
+
+            for max_size in max_gpu_sizes:
+                max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
+                new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+                # Making sure part of the model will actually end up offloaded
+                self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
+
+                self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+
+                torch.manual_seed(0)
+                new_output = new_model(**inputs_dict)
+
+                self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+
 
 @is_staging_test
 class ModelPushToHubTester(unittest.TestCase):
diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py
index 1b8a998cfd..33aa6a1037 100644
--- a/tests/models/unets/test_models_unet_2d_condition.py
+++ b/tests/models/unets/test_models_unet_2d_condition.py
@@ -300,6 +300,8 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
 class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet2DConditionModel
     main_input_name = "sample"
+    # We override the items here because the unet under consideration is small.
+    model_split_percents = [0.5, 0.3, 0.4]
 
     @property
     def dummy_input(self):
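Usage sketch (not part of the patch, added for illustration): with this change, `ModelMixin.from_pretrained` accepts `device_map="auto"` (as well as "balanced", "balanced_low_0", and "sequential") and resolves it through `_determine_device_map`, so that modules whose class names appear in a model's `_no_split_modules` are never split across devices. A minimal example of how the feature would be exercised follows; the checkpoint ID and the `max_memory` budget are illustrative assumptions, not values taken from the patch.

import torch

from diffusers import UNet2DConditionModel

# device_map="auto" is resolved via _determine_device_map; modules listed in
# UNet2DConditionModel._no_split_modules (e.g. ResnetBlock2D, which carries a
# residual connection) are kept whole on a single device.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # illustrative checkpoint
    subfolder="unet",
    torch_dtype=torch.float16,
    device_map="auto",
    max_memory={0: "6GiB", "cpu": "16GiB"},  # optional per-device budget (assumed values)
)

# The resulting placement is recorded on the model, mirroring what the new
# tests assert against via hf_device_map.
print(unet.hf_device_map)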