Mirror of https://github.com/huggingface/diffusers.git
[Low cpu memory] Correct naming and improve default usage (#1122)
* correct naming
* finish
* Apply suggestions from code review
* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Committed by GitHub · parent 988c82227d · commit 42bb459457
@@ -35,6 +35,12 @@ from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT
 logger = logging.get_logger(__name__)
 
 
+if is_torch_version(">=", "1.9.0"):
+    _LOW_CPU_MEM_USAGE_DEFAULT = True
+else:
+    _LOW_CPU_MEM_USAGE_DEFAULT = False
+
+
 def get_parameter_device(parameter: torch.nn.Module):
     try:
         return next(parameter.parameters()).device
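Note: the new module-level default above reduces to a plain torch version check. A minimal standalone sketch of the same gate, using `packaging` directly instead of the library's `is_torch_version` helper (which wraps the same comparison):

```python
# Standalone sketch of the _LOW_CPU_MEM_USAGE_DEFAULT gate above.
# Assumes only `torch` and `packaging` are installed.
import torch
from packaging import version

# base_version strips local build tags such as "+cu117" before comparing.
_LOW_CPU_MEM_USAGE_DEFAULT = version.parse(
    version.parse(torch.__version__).base_version
) >= version.parse("1.9.0")

print(_LOW_CPU_MEM_USAGE_DEFAULT)  # True on torch >= 1.9.0
```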
@@ -278,11 +284,11 @@ class ModelMixin(torch.nn.Module):
                 To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
                 more information about each option see [designing a device
                 map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
-            fast_load (`bool`, *optional*, defaults to `True`):
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                 Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
                 also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                 model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
-                this argument will be ignored and the model will be loaded normally.
+                setting this argument to `True` will raise an error.
 
         <Tip>
 
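For context, this is how the renamed argument looks from the caller's side; a hedged sketch (checkpoint id and `subfolder` are illustrative, borrowed from the Stable Diffusion repo used in the integration tests below):

```python
# Sketch: loading a model with and without low_cpu_mem_usage.
import torch
from diffusers import UNet2DConditionModel

# Default on torch >= 1.9.0: weights are not initialized, the pretrained
# weights are loaded directly (at most ~1x model size in CPU memory).
unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet"
)

# Opt out explicitly, e.g. to reproduce the old loading behaviour:
unet_slow = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", low_cpu_mem_usage=False
)
```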
@@ -311,16 +317,26 @@ class ModelMixin(torch.nn.Module):
         torch_dtype = kwargs.pop("torch_dtype", None)
         subfolder = kwargs.pop("subfolder", None)
         device_map = kwargs.pop("device_map", None)
-        fast_load = kwargs.pop("fast_load", True)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
 
         # Check if we can handle device_map and dispatching the weights
         if device_map is not None and not is_torch_version(">=", "1.9.0"):
-            raise NotImplementedError("Loading and dispatching requires torch >= 1.9.0")
-
-        # Fast init is only possible if torch version is >= 1.9.0
-        _INIT_EMPTY_WEIGHTS = fast_load or device_map is not None
-        if _INIT_EMPTY_WEIGHTS and not is_torch_version(">=", "1.9.0"):
-            logger.warn("Loading with `fast_load` requires torch >= 1.9.0. Falling back to normal loading.")
+            raise NotImplementedError(
+                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `device_map=None`."
+            )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        if low_cpu_mem_usage is False and device_map is not None:
+            raise ValueError(
+                f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
+                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
+            )
 
         user_agent = {
             "diffusers": __version__,
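The three guards above make the torch requirement explicit instead of silently falling back, as the old `logger.warn` path did. A hedged sketch of the one combination rejected outright on torch >= 1.9.0 (checkpoint id reused from the tests in this commit):

```python
# Sketch: low_cpu_mem_usage=False is incompatible with a device_map, since
# dispatching relies on the accelerate loading path. Assumes torch >= 1.9.0
# and accelerate installed.
from diffusers import UNet2DModel

try:
    UNet2DModel.from_pretrained(
        "fusing/unet-ldm-dummy-update",
        device_map="auto",
        low_cpu_mem_usage=False,
    )
except ValueError as err:
    print(err)  # "...Please make sure to set `low_cpu_mem_usage=True`."
```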
@@ -403,7 +419,7 @@ class ModelMixin(torch.nn.Module):
         # restore default dtype
 
-        if _INIT_EMPTY_WEIGHTS:
+        if low_cpu_mem_usage:
             # Instantiate model with empty weights
             with accelerate.init_empty_weights():
                 model, unused_kwargs = cls.from_config(
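For readers unfamiliar with the accelerate call above, a minimal sketch of what `init_empty_weights()` does in isolation:

```python
# Sketch: inside init_empty_weights(), parameters are created on the "meta"
# device, so no CPU memory is allocated and no init kernels run.
# from_pretrained then fills this skeleton with the pretrained weights.
import torch
from accelerate import init_empty_weights

with init_empty_weights():
    layer = torch.nn.Linear(4096, 4096)

print(layer.weight.device)  # meta
```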
@@ -25,6 +25,7 @@ import torch
 
 import diffusers
 import PIL
+from accelerate.utils.versions import is_torch_version
 from huggingface_hub import snapshot_download
 from packaging import version
 from PIL import Image
@@ -33,6 +34,7 @@ from tqdm.auto import tqdm
 from .configuration_utils import ConfigMixin
 from .dynamic_modules_utils import get_class_from_dynamic_module
 from .hub_utils import http_user_agent
+from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
 from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from .utils import (
     CONFIG_NAME,
@@ -328,6 +330,19 @@ class DiffusionPipeline(ConfigMixin):
                 Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                 problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
                 Please refer to the mirror site for more information. specify the folder name here.
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn't need to be refined to each
+                parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
+                same device.
+
+                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
+                setting this argument to `True` will raise an error.
 
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
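Taken together, the documented pipeline arguments can be exercised like this; a hedged sketch (pipeline id as used in the integration test below):

```python
# Sketch: pipeline-level device_map with the default low_cpu_mem_usage=True,
# which is required whenever a device_map is passed. Assumes torch >= 1.9.0
# and accelerate installed.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    torch_dtype=torch.float16,
    device_map="auto",  # let Accelerate place submodules across devices
)
```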
@@ -380,7 +395,25 @@ class DiffusionPipeline(ConfigMixin):
         provider = kwargs.pop("provider", None)
         sess_options = kwargs.pop("sess_options", None)
         device_map = kwargs.pop("device_map", None)
-        fast_load = kwargs.pop("fast_load", True)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+        if device_map is not None and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `device_map=None`."
+            )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        if low_cpu_mem_usage is False and device_map is not None:
+            raise ValueError(
+                f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
+                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
+            )
 
         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
@@ -573,17 +606,12 @@ class DiffusionPipeline(ConfigMixin):
             and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
         )
 
-        if is_diffusers_model:
-            loading_kwargs["fast_load"] = fast_load
-
         # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
-        # To make default loading faster we set the `low_cpu_mem_usage=fast_load` flag which is `True` by default.
+        # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
         # This makes sure that the weights won't be initialized which significantly speeds up loading.
-        if is_transformers_model and device_map is None:
-            loading_kwargs["low_cpu_mem_usage"] = fast_load
-
         if is_diffusers_model or is_transformers_model:
             loading_kwargs["device_map"] = device_map
+            loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
 
         # check if the module is in a subdirectory
         if os.path.isdir(os.path.join(cached_folder, name)):
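The reason one flag can be forwarded to both kinds of submodels is that transformers exposes the same `low_cpu_mem_usage` argument on its own `from_pretrained` (the >= 4.20.0 check above gates this path). A hedged sketch with an illustrative model id:

```python
# Sketch: transformers accepts the same flag, so the pipeline can pass one
# loading_kwargs dict to diffusers and transformers submodels alike.
# Model id is illustrative (the Stable Diffusion text encoder).
from transformers import CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14", low_cpu_mem_usage=True
)
```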
@@ -133,7 +133,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
 
     @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
     def test_from_pretrained_accelerate_wont_change_results(self):
-        # by defautl model loading will use accelerate as `fast_load=True`
+        # by defautl model loading will use accelerate as `low_cpu_mem_usage=True`
         model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
         model_accelerate.to(torch_device)
         model_accelerate.eval()
@@ -156,7 +156,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
         gc.collect()
 
         model_normal_load, _ = UNet2DModel.from_pretrained(
-            "fusing/unet-ldm-dummy-update", output_loading_info=True, fast_init=False
+            "fusing/unet-ldm-dummy-update", output_loading_info=True, low_cpu_mem_usage=False
         )
         model_normal_load.to(torch_device)
         model_normal_load.eval()
@@ -170,7 +170,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
         gc.collect()
 
         tracemalloc.start()
-        # by defautl model loading will use accelerate as `fast_load=True`
+        # by defautl model loading will use accelerate as `low_cpu_mem_usage=True`
         model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
         model_accelerate.to(torch_device)
         model_accelerate.eval()
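The memory assertion in this test relies on `tracemalloc`; a minimal standalone sketch of the measurement pattern it uses:

```python
# Sketch: tracemalloc tracks Python-level allocations;
# get_traced_memory() returns (current, peak) in bytes.
import tracemalloc

tracemalloc.start()
blob = [0] * 1_000_000  # stand-in for the model loading measured in the test
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
print(f"current={current}, peak={peak}")
```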
@@ -181,7 +181,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
         gc.collect()
 
         model_normal_load, _ = UNet2DModel.from_pretrained(
-            "fusing/unet-ldm-dummy-update", output_loading_info=True, fast_init=False
+            "fusing/unet-ldm-dummy-update", output_loading_info=True, low_cpu_mem_usage=False
         )
         model_normal_load.to(torch_device)
         model_normal_load.eval()
@@ -823,23 +823,23 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
         assert test_callback_fn.has_been_called
         assert number_of_steps == 51
 
-    def test_stable_diffusion_fast_load(self):
+    def test_stable_diffusion_low_cpu_mem_usage(self):
         pipeline_id = "CompVis/stable-diffusion-v1-4"
 
         start_time = time.time()
-        pipeline_fast_load = StableDiffusionPipeline.from_pretrained(
+        pipeline_low_cpu_mem_usage = StableDiffusionPipeline.from_pretrained(
             pipeline_id, revision="fp16", torch_dtype=torch.float16
         )
-        pipeline_fast_load.to(torch_device)
-        fast_load_time = time.time() - start_time
+        pipeline_low_cpu_mem_usage.to(torch_device)
+        low_cpu_mem_usage_time = time.time() - start_time
 
         start_time = time.time()
         _ = StableDiffusionPipeline.from_pretrained(
-            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, fast_load=False
+            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, low_cpu_mem_usage=False
         )
         normal_load_time = time.time() - start_time
 
-        assert 2 * fast_load_time < normal_load_time
+        assert 2 * low_cpu_mem_usage_time < normal_load_time
 
     @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
     def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):