From 1ecfbfe12b330a34e3e7893d794b469b9eef5b02 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 14 Jan 2026 15:59:36 +0000 Subject: [PATCH] `disable_mmap` in pipeline `from_pretrained` (#12854) * update * `disable_mmap` in `from_pretrained` --------- Co-authored-by: DN6 --- src/diffusers/models/model_loading_utils.py | 5 ++++- src/diffusers/models/modeling_utils.py | 3 +++ src/diffusers/pipelines/pipeline_loading_utils.py | 4 ++++ src/diffusers/pipelines/pipeline_utils.py | 5 +++++ 4 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 08b3f0234f..6d2e8df9c2 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -355,8 +355,9 @@ def _load_shard_file( state_dict_folder=None, ignore_mismatched_sizes=False, low_cpu_mem_usage=False, + disable_mmap=False, ): - state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries) + state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, disable_mmap=disable_mmap) mismatched_keys = _find_mismatched_keys( state_dict, model_state_dict, @@ -402,6 +403,7 @@ def _load_shard_files_with_threadpool( state_dict_folder=None, ignore_mismatched_sizes=False, low_cpu_mem_usage=False, + disable_mmap=False, ): # Do not spawn anymore workers than you need num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS) @@ -428,6 +430,7 @@ def _load_shard_files_with_threadpool( state_dict_folder=state_dict_folder, ignore_mismatched_sizes=ignore_mismatched_sizes, low_cpu_mem_usage=low_cpu_mem_usage, + disable_mmap=disable_mmap, ) tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"} diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 482c6d0103..0ccd4c480e 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1306,6 +1306,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): keep_in_fp32_modules=keep_in_fp32_modules, dduf_entries=dduf_entries, is_parallel_loading_enabled=is_parallel_loading_enabled, + disable_mmap=disable_mmap, ) loading_info = { "missing_keys": missing_keys, @@ -1591,6 +1592,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): offload_folder: Optional[Union[str, os.PathLike]] = None, dduf_entries: Optional[Dict[str, DDUFEntry]] = None, is_parallel_loading_enabled: Optional[bool] = False, + disable_mmap: bool = False, ): model_state_dict = model.state_dict() expected_keys = list(model_state_dict.keys()) @@ -1659,6 +1661,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): state_dict_folder=state_dict_folder, ignore_mismatched_sizes=ignore_mismatched_sizes, low_cpu_mem_usage=low_cpu_mem_usage, + disable_mmap=disable_mmap, ) if is_parallel_loading_enabled: diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index bda80e2fe2..57d4eaa8f8 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -758,6 +758,7 @@ def load_sub_model( use_safetensors: bool, dduf_entries: Optional[Dict[str, DDUFEntry]], provider_options: Any, + disable_mmap: bool, quantization_config: Optional[Any] = None, ): """Helper method to load the module `name` from `library_name` and `class_name`""" @@ -859,6 +860,9 @@ def load_sub_model( else: loading_kwargs["low_cpu_mem_usage"] = False + if is_diffusers_model: + loading_kwargs["disable_mmap"] = disable_mmap + if is_transformers_model and is_transformers_version(">=", "4.57.0"): loading_kwargs.pop("offload_state_dict") diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index c4f118d7e1..b96305c741 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -721,6 +721,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): loading `from_flax`. dduf_file(`str`, *optional*): Load weights from the specified dduf file. + disable_mmap ('bool', *optional*, defaults to 'False'): + Whether to disable mmap when loading a Safetensors model. This option can perform better when the model + is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well. > [!TIP] > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf > auth login`. @@ -772,6 +775,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): use_onnx = kwargs.pop("use_onnx", None) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) quantization_config = kwargs.pop("quantization_config", None) + disable_mmap = kwargs.pop("disable_mmap", False) if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype): torch_dtype = torch.float32 @@ -1059,6 +1063,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): use_safetensors=use_safetensors, dduf_entries=dduf_entries, provider_options=provider_options, + disable_mmap=disable_mmap, quantization_config=quantization_config, ) logger.info(