mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
disable_mmap in pipeline from_pretrained (#12854)
* update * `disable_mmap` in `from_pretrained` --------- Co-authored-by: DN6 <dhruv.nair@gmail.com>
This commit is contained in:
@@ -355,8 +355,9 @@ def _load_shard_file(
|
||||
state_dict_folder=None,
|
||||
ignore_mismatched_sizes=False,
|
||||
low_cpu_mem_usage=False,
|
||||
disable_mmap=False,
|
||||
):
|
||||
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
|
||||
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, disable_mmap=disable_mmap)
|
||||
mismatched_keys = _find_mismatched_keys(
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
@@ -402,6 +403,7 @@ def _load_shard_files_with_threadpool(
|
||||
state_dict_folder=None,
|
||||
ignore_mismatched_sizes=False,
|
||||
low_cpu_mem_usage=False,
|
||||
disable_mmap=False,
|
||||
):
|
||||
# Do not spawn anymore workers than you need
|
||||
num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)
|
||||
@@ -428,6 +430,7 @@ def _load_shard_files_with_threadpool(
|
||||
state_dict_folder=state_dict_folder,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
disable_mmap=disable_mmap,
|
||||
)
|
||||
|
||||
tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"}
|
||||
|
||||
@@ -1306,6 +1306,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
dduf_entries=dduf_entries,
|
||||
is_parallel_loading_enabled=is_parallel_loading_enabled,
|
||||
disable_mmap=disable_mmap,
|
||||
)
|
||||
loading_info = {
|
||||
"missing_keys": missing_keys,
|
||||
@@ -1591,6 +1592,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
offload_folder: Optional[Union[str, os.PathLike]] = None,
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
|
||||
is_parallel_loading_enabled: Optional[bool] = False,
|
||||
disable_mmap: bool = False,
|
||||
):
|
||||
model_state_dict = model.state_dict()
|
||||
expected_keys = list(model_state_dict.keys())
|
||||
@@ -1659,6 +1661,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
state_dict_folder=state_dict_folder,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
disable_mmap=disable_mmap,
|
||||
)
|
||||
|
||||
if is_parallel_loading_enabled:
|
||||
|
||||
@@ -758,6 +758,7 @@ def load_sub_model(
|
||||
use_safetensors: bool,
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]],
|
||||
provider_options: Any,
|
||||
disable_mmap: bool,
|
||||
quantization_config: Optional[Any] = None,
|
||||
):
|
||||
"""Helper method to load the module `name` from `library_name` and `class_name`"""
|
||||
@@ -859,6 +860,9 @@ def load_sub_model(
|
||||
else:
|
||||
loading_kwargs["low_cpu_mem_usage"] = False
|
||||
|
||||
if is_diffusers_model:
|
||||
loading_kwargs["disable_mmap"] = disable_mmap
|
||||
|
||||
if is_transformers_model and is_transformers_version(">=", "4.57.0"):
|
||||
loading_kwargs.pop("offload_state_dict")
|
||||
|
||||
|
||||
@@ -721,6 +721,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
loading `from_flax`.
|
||||
dduf_file(`str`, *optional*):
|
||||
Load weights from the specified dduf file.
|
||||
disable_mmap ('bool', *optional*, defaults to 'False'):
|
||||
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
||||
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
|
||||
|
||||
> [!TIP] > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in
|
||||
with `hf > auth login`.
|
||||
@@ -772,6 +775,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
use_onnx = kwargs.pop("use_onnx", None)
|
||||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
@@ -1059,6 +1063,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
use_safetensors=use_safetensors,
|
||||
dduf_entries=dduf_entries,
|
||||
provider_options=provider_options,
|
||||
disable_mmap=disable_mmap,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
logger.info(
|
||||
|
||||
Reference in New Issue
Block a user