Mirror of https://github.com/huggingface/diffusers.git

[Core] add "balanced" device_map support to pipelines (#6857)
* get device <-> component mapping when using multiple gpus.
* condition the device_map bits.
* relax condition
* device_map progress.
* device_map enhancement
* some cleaning up and debugging
* Apply suggestions from code review Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* incorporate suggestions from PR.
* remove multi-gpu condition for now.
* guard check the component -> device mapping
* fix: device_memory variable
* dispatching transformers model to have force_hooks=True
* better guarding for transformers device_map
* introduce support balanced_low_memory and balanced_ultra_low_memory.
* remove device_map patch.
* fix: intermediate variable scoping.
* fix: condition in cpu offload.
* fix: flax class restrictions.
* remove modifications from cpu_offload and model_offload
* incorporate changes.
* add a simple forward pass test
* add: torch_device in get_inputs()
* add: tests
* remove print
* safe-guard to(), model offloading and cpu offloading when balanced is used as a device_map.
* style
* remove .
* safeguard device_map with more checks and remove invalid device_mapping strategies.
* make a class attribute and adjust tests accordingly.
* fix device_map check
* fix test
* adjust comment
* fix: device_map attribute
* fix: dispatching.
* max_memory test for pipeline
* version guard the tests
* fix guard.
* address review feedback.
* reset_device_map method.
* add: test for reset_hf_device_map
* fix a couple things.
* add reset_device_map() in the error message.
* add tests for checking reset_device_map doesn't have unintended consequences.
* fix reset_device_map and offloading tests.
* create _get_final_device_map utility.
* hf_device_map -> _hf_device_map
* add documentation
* add notes suggested by Marc.
* styling.
* Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* move updates within gpu condition.
* other docs related things
* note on ignore a device not specified in .
* provide a suggestion if device mapping errors out.
* fix: typo.
* _hf_device_map -> hf_device_map
* Empty-Commit
* add: example hf_device_map.

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

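For context, a minimal usage sketch of the feature this commit adds (the checkpoint id and memory budgets below are placeholders, not taken from the diff): loading a pipeline with `device_map="balanced"` splits the model-level components across the available GPUs, records the placement in `hf_device_map`, and `reset_device_map()` undoes it.

import torch
from diffusers import DiffusionPipeline

# Illustrative only: any diffusers checkpoint works; this id is just an example.
pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    device_map="balanced",            # the only strategy accepted by this PR
    max_memory={0: "8GB", 1: "8GB"},  # optional per-GPU budget
)
print(pipeline.hf_device_map)          # e.g. {"unet": 0, "text_encoder": 1, "vae": 1}

pipeline.reset_device_map()            # required before .to(), enable_model_cpu_offload(), etc.
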
@@ -699,6 +699,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        offload_folder=offload_folder,
        offload_state_dict=offload_state_dict,
        dtype=torch_dtype,
        force_hooks=True,
    )
except AttributeError as e:
    # When using accelerate loading, we do not have the ability to load the state

@@ -22,15 +22,19 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import torch
from huggingface_hub import (
    model_info,
)
from huggingface_hub import model_info
from huggingface_hub.utils import validate_hf_hub_args
from packaging import version

from .. import __version__
from ..utils import (
    FLAX_WEIGHTS_NAME,
    ONNX_EXTERNAL_WEIGHTS_NAME,
    ONNX_WEIGHTS_NAME,
    SAFETENSORS_WEIGHTS_NAME,
    WEIGHTS_NAME,
    get_class_from_dynamic_module,
    is_accelerate_available,
    is_peft_available,
    is_transformers_available,
    logging,

@@ -44,9 +48,12 @@ if is_transformers_available():
    from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
    from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
    from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
from huggingface_hub.utils import validate_hf_hub_args

from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
if is_accelerate_available():
    import accelerate
    from accelerate import dispatch_model
    from accelerate.hooks import remove_hook_from_module
    from accelerate.utils import compute_module_sizes, get_max_memory


INDEX_FILE = "diffusion_pytorch_model.bin"

@@ -376,6 +383,207 @@ def _get_pipeline_class(
    return pipeline_cls


def _load_empty_model(
    library_name: str,
    class_name: str,
    importable_classes: List[Any],
    pipelines: Any,
    is_pipeline_module: bool,
    name: str,
    torch_dtype: Union[str, torch.dtype],
    cached_folder: Union[str, os.PathLike],
    **kwargs,
):
    # retrieve class objects.
    class_obj, _ = get_class_obj_and_candidates(
        library_name,
        class_name,
        importable_classes,
        pipelines,
        is_pipeline_module,
        component_name=name,
        cache_dir=cached_folder,
    )

    if is_transformers_available():
        transformers_version = version.parse(version.parse(transformers.__version__).base_version)
    else:
        transformers_version = "N/A"

    # Determine library.
    is_transformers_model = (
        is_transformers_available()
        and issubclass(class_obj, PreTrainedModel)
        and transformers_version >= version.parse("4.20.0")
    )
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)

    model = None
    config_path = cached_folder
    user_agent = {
        "diffusers": __version__,
        "file_type": "model",
        "framework": "pytorch",
    }

    if is_diffusers_model:
        # Load config and then the model on meta.
        config, unused_kwargs, commit_hash = class_obj.load_config(
            os.path.join(config_path, name),
            cache_dir=cached_folder,
            return_unused_kwargs=True,
            return_commit_hash=True,
            force_download=kwargs.pop("force_download", False),
            resume_download=kwargs.pop("resume_download", False),
            proxies=kwargs.pop("proxies", None),
            local_files_only=kwargs.pop("local_files_only", False),
            token=kwargs.pop("token", None),
            revision=kwargs.pop("revision", None),
            subfolder=kwargs.pop("subfolder", None),
            user_agent=user_agent,
        )
        with accelerate.init_empty_weights():
            model = class_obj.from_config(config, **unused_kwargs)
    elif is_transformers_model:
        config_class = getattr(class_obj, "config_class", None)
        if config_class is None:
            raise ValueError("`config_class` cannot be None. Please double-check the model.")

        config = config_class.from_pretrained(
            cached_folder,
            subfolder=name,
            force_download=kwargs.pop("force_download", False),
            resume_download=kwargs.pop("resume_download", False),
            proxies=kwargs.pop("proxies", None),
            local_files_only=kwargs.pop("local_files_only", False),
            token=kwargs.pop("token", None),
            revision=kwargs.pop("revision", None),
            user_agent=user_agent,
        )
        with accelerate.init_empty_weights():
            model = class_obj(config)

    if model is not None:
        model = model.to(dtype=torch_dtype)
    return model


def _assign_components_to_devices(
    module_sizes: Dict[str, float], device_memory: Dict[str, float], device_mapping_strategy: str = "balanced"
):
    device_ids = list(device_memory.keys())
    device_cycle = device_ids + device_ids[::-1]
    device_memory = device_memory.copy()

    device_id_component_mapping = {}
    current_device_index = 0
    for component in module_sizes:
        device_id = device_cycle[current_device_index % len(device_cycle)]
        component_memory = module_sizes[component]
        curr_device_memory = device_memory[device_id]

        # If the GPU doesn't fit the current component offload to the CPU.
        if component_memory > curr_device_memory:
            device_id_component_mapping["cpu"] = [component]
        else:
            if device_id not in device_id_component_mapping:
                device_id_component_mapping[device_id] = [component]
            else:
                device_id_component_mapping[device_id].append(component)

            # Update the device memory.
            device_memory[device_id] -= component_memory
            current_device_index += 1

    return device_id_component_mapping


def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
    # To avoid circular import problem.
    from diffusers import pipelines

    torch_dtype = kwargs.get("torch_dtype", torch.float32)

    # Load each module in the pipeline on a meta device so that we can derive the device map.
    init_empty_modules = {}
    for name, (library_name, class_name) in init_dict.items():
        if class_name.startswith("Flax"):
            raise ValueError("Flax pipelines are not supported with `device_map`.")

        # Define all importable classes
        is_pipeline_module = hasattr(pipelines, library_name)
        importable_classes = ALL_IMPORTABLE_CLASSES
        loaded_sub_model = None

        # Use passed sub model or load class_name from library_name
        if name in passed_class_obj:
            # if the model is in a pipeline module, then we load it from the pipeline
            # check that passed_class_obj has correct parent class
            maybe_raise_or_warn(
                library_name,
                library,
                class_name,
                importable_classes,
                passed_class_obj,
                name,
                is_pipeline_module,
            )
            with accelerate.init_empty_weights():
                loaded_sub_model = passed_class_obj[name]

        else:
            loaded_sub_model = _load_empty_model(
                library_name=library_name,
                class_name=class_name,
                importable_classes=importable_classes,
                pipelines=pipelines,
                is_pipeline_module=is_pipeline_module,
                pipeline_class=pipeline_class,
                name=name,
                torch_dtype=torch_dtype,
                cached_folder=kwargs.get("cached_folder", None),
                force_download=kwargs.get("force_download", None),
                resume_download=kwargs.get("resume_download", None),
                proxies=kwargs.get("proxies", None),
                local_files_only=kwargs.get("local_files_only", None),
                token=kwargs.get("token", None),
                revision=kwargs.get("revision", None),
            )

        if loaded_sub_model is not None:
            init_empty_modules[name] = loaded_sub_model

    # determine device map
    # Obtain a sorted dictionary for mapping the model-level components
    # to their sizes.
    module_sizes = {
        module_name: compute_module_sizes(module, dtype=torch_dtype)[""]
        for module_name, module in init_empty_modules.items()
        if isinstance(module, torch.nn.Module)
    }
    module_sizes = dict(sorted(module_sizes.items(), key=lambda item: item[1], reverse=True))

    # Obtain maximum memory available per device (GPUs only).
    max_memory = get_max_memory(max_memory)
    max_memory = dict(sorted(max_memory.items(), key=lambda item: item[1], reverse=True))
    max_memory = {k: v for k, v in max_memory.items() if k != "cpu"}

    # Obtain a dictionary mapping the model-level components to the available
    # devices based on the maximum memory and the model sizes.
    device_id_component_mapping = _assign_components_to_devices(
        module_sizes, max_memory, device_mapping_strategy=device_map
    )

    # Obtain the final device map, e.g., `{"unet": 0, "text_encoder": 1, "vae": 1, ...}`
    final_device_map = {}
    for device_id, components in device_id_component_mapping.items():
        for component in components:
            final_device_map[component] = device_id

    return final_device_map


def load_sub_model(
    library_name: str,
    class_name: str,

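To make the balanced strategy concrete, here is a toy, illustrative call to the `_assign_components_to_devices` helper added above (assuming it stays importable from `diffusers.pipelines.pipeline_loading_utils`; the sizes and budgets are made up):

from diffusers.pipelines.pipeline_loading_utils import _assign_components_to_devices

# Sizes in bytes, sorted largest-first, as _get_final_device_map does before calling the helper.
module_sizes = {"unet": 3_500_000_000, "text_encoder": 500_000_000, "vae": 300_000_000}
device_memory = {0: 8_000_000_000, 1: 8_000_000_000}  # two GPUs with roughly 8 GB each

mapping = _assign_components_to_devices(module_sizes, device_memory, device_mapping_strategy="balanced")
print(mapping)  # {0: ["unet"], 1: ["text_encoder", "vae"]} -- devices are cycled 0, 1, 1, 0, ...
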
@@ -493,6 +701,22 @@ def load_sub_model(
    # else load from the root directory
    loaded_sub_model = load_method(cached_folder, **loading_kwargs)

    if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict):
        # remove hooks
        remove_hook_from_module(loaded_sub_model, recurse=True)
        needs_offloading_to_cpu = device_map[""] == "cpu"

        if needs_offloading_to_cpu:
            dispatch_model(
                loaded_sub_model,
                state_dict=loaded_sub_model.state_dict(),
                device_map=device_map,
                force_hooks=True,
                main_device=0,
            )
        else:
            dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)

    return loaded_sub_model

@@ -73,6 +73,7 @@ from .pipeline_loading_utils import (
    LOADABLE_CLASSES,
    _fetch_class_library_tuple,
    _get_custom_pipeline_class,
    _get_final_device_map,
    _get_pipeline_class,
    _unwrap_model,
    is_safetensors_compatible,

@@ -91,6 +92,8 @@ LIBRARIES = []
for library in LOADABLE_CLASSES:
    LIBRARIES.append(library)

SUPPORTED_DEVICE_MAP = ["balanced"]

logger = logging.get_logger(__name__)

@@ -141,6 +144,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):

    config_name = "model_index.json"
    model_cpu_offload_seq = None
    hf_device_map = None
    _optional_components = []
    _exclude_from_cpu_offload = []
    _load_connected_pipes = False

@@ -389,6 +393,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
            )

        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
        if is_pipeline_device_mapped:
            raise ValueError(
                "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`."
            )

        # Display a warning in this case (the operation succeeds but the benefits are lost)
        pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
        if pipeline_is_offloaded and device and torch.device(device).type == "cuda":

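As a consequence of the guard above, explicit device placement has to be preceded by a reset. A minimal sketch, assuming `pipeline` was loaded with `device_map="balanced"`:

try:
    pipeline.to("cuda")            # raises ValueError while a device map is active
except ValueError:
    pipeline.reset_device_map()    # clears hf_device_map and removes the hooks
    pipeline.to("cuda")            # explicit placement works again
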
@@ -642,18 +652,35 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                " install accelerate\n```\n."
            )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        if device_map is not None and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `device_map=None`."
            )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
        if device_map is not None and not is_accelerate_available():
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
                "Using `device_map` requires the `accelerate` library. Please install it using: `pip install accelerate`."
            )

        if device_map is not None and not isinstance(device_map, str):
            raise ValueError("`device_map` must be a string.")

        if device_map is not None and device_map not in SUPPORTED_DEVICE_MAP:
            raise NotImplementedError(
                f"{device_map} not supported. Supported strategies are: {', '.join(SUPPORTED_DEVICE_MAP)}"
            )

        if device_map is not None and device_map in SUPPORTED_DEVICE_MAP:
            if is_accelerate_version("<", "0.28.0"):
                raise NotImplementedError("Device placement requires `accelerate` version `0.28.0` or later.")

        if low_cpu_mem_usage is False and device_map is not None:
            raise ValueError(
                f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"

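The validation above rejects anything other than the supported strategy strings. Illustrative failure modes (the repo id is a placeholder):

# ValueError: `device_map` must be a string.
DiffusionPipeline.from_pretrained("some/repo-id", device_map={"unet": 0})

# NotImplementedError: auto not supported. Supported strategies are: balanced
DiffusionPipeline.from_pretrained("some/repo-id", device_map="auto")
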
@@ -729,6 +756,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
            revision=custom_revision,
        )

        if device_map is not None and pipeline_class._load_connected_pipes:
            raise NotImplementedError("`device_map` is not yet supported for connected pipelines.")

        # DEPRECATED: To be removed in 1.0.0
        if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
            version.parse(config_dict["_diffusers_version"]).base_version

@@ -795,17 +825,45 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
        # import it here to avoid circular import
        from diffusers import pipelines

        # 6. Load each module in the pipeline
        # 6. device map delegation
        final_device_map = None
        if device_map is not None:
            final_device_map = _get_final_device_map(
                device_map=device_map,
                pipeline_class=pipeline_class,
                passed_class_obj=passed_class_obj,
                init_dict=init_dict,
                library=library,
                max_memory=max_memory,
                torch_dtype=torch_dtype,
                cached_folder=cached_folder,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
            )

        # 7. Load each module in the pipeline
        current_device_map = None
        for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
            # 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
            if final_device_map is not None and len(final_device_map) > 0:
                component_device = final_device_map.get(name, None)
                if component_device is not None:
                    current_device_map = {"": component_device}
                else:
                    current_device_map = None

            # 7.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
            class_name = class_name[4:] if class_name.startswith("Flax") else class_name

            # 6.2 Define all importable classes
            # 7.2 Define all importable classes
            is_pipeline_module = hasattr(pipelines, library_name)
            importable_classes = ALL_IMPORTABLE_CLASSES
            loaded_sub_model = None

            # 6.3 Use passed sub model or load class_name from library_name
            # 7.3 Use passed sub model or load class_name from library_name
            if name in passed_class_obj:
                # if the model is in a pipeline module, then we load it from the pipeline
                # check that passed_class_obj has correct parent class

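For reference, the shape of the two maps built in steps 6 and 7 above (values are illustrative):

final_device_map = {"unet": 0, "text_encoder": 1, "vae": 1}   # pipeline-level map, later exposed as hf_device_map
current_device_map = {"": final_device_map["unet"]}           # per-component map handed to load_sub_model, i.e. {"": 0}
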
@@ -826,7 +884,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                torch_dtype=torch_dtype,
                provider=provider,
                sess_options=sess_options,
                device_map=device_map,
                device_map=current_device_map,
                max_memory=max_memory,
                offload_folder=offload_folder,
                offload_state_dict=offload_state_dict,

@@ -893,7 +951,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                    {"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
                )

        # 7. Potentially add passed objects if expected
        # 8. Potentially add passed objects if expected
        missing_modules = set(expected_modules) - set(init_kwargs.keys())
        passed_modules = list(passed_class_obj.keys())
        optional_modules = pipeline_class._optional_components

@@ -906,11 +964,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
            )

        # 8. Instantiate the pipeline
        # 10. Instantiate the pipeline
        model = pipeline_class(**init_kwargs)

        # 9. Save where the model was instantiated from
        # 11. Save where the model was instantiated from
        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
        if device_map is not None:
            setattr(model, "hf_device_map", final_device_map)
        return model

    @property

@@ -963,6 +1023,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
                default to "cuda".
        """
        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
        if is_pipeline_device_mapped:
            raise ValueError(
                "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
            )

        if self.model_cpu_offload_seq is None:
            raise ValueError(
                "Model CPU offload cannot be enabled because no `model_cpu_offload_seq` class attribute is set."

@@ -1056,6 +1122,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
            raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
        self.remove_all_hooks()

        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
        if is_pipeline_device_mapped:
            raise ValueError(
                "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
            )

        torch_device = torch.device(device)
        device_index = torch_device.index

@@ -1090,6 +1162,19 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
            offload_buffers = len(model._parameters) > 0
            cpu_offload(model, device, offload_buffers=offload_buffers)

    def reset_device_map(self):
        r"""
        Resets the device maps (if any) to None.
        """
        if self.hf_device_map is None:
            return
        else:
            self.remove_all_hooks()
            for name, component in self.components.items():
                if isinstance(component, torch.nn.Module):
                    component.to("cpu")
            self.hf_device_map = None

    @classmethod
    @validate_hf_hub_args
    def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:

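A short sketch of what `reset_device_map()` gives back to the user, assuming a pipeline loaded with `device_map="balanced"`:

pipeline.reset_device_map()               # removes hooks and moves every torch.nn.Module component to "cpu"
assert pipeline.hf_device_map is None
pipeline.enable_model_cpu_offload()       # offloading (or .to(...)) is no longer blocked
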
@@ -255,6 +255,20 @@ def require_torch_accelerator(test_case):
    )


def require_torch_multi_gpu(test_case):
    """
    Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without
    multiple GPUs. To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests
    -k "multi_gpu"
    """
    if not is_torch_available():
        return unittest.skip("test requires PyTorch")(test_case)

    import torch

    return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)


def require_torch_accelerator_with_fp16(test_case):
    """Decorator marking a test that requires an accelerator with support for the FP16 data type."""
    return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(

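Typical usage of the new decorator on a test method (the test name here is hypothetical):

@require_torch_multi_gpu
def test_balanced_device_map_forward_pass(self):
    ...
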
@@ -343,6 +357,18 @@ def require_peft_version_greater(peft_version):
    return decorator


def require_accelerate_version_greater(accelerate_version):
    def decorator(test_case):
        correct_accelerate_version = is_peft_available() and version.parse(
            version.parse(importlib.metadata.version("accelerate")).base_version
        ) > version.parse(accelerate_version)
        return unittest.skipUnless(
            correct_accelerate_version, f"Test requires accelerate with the version greater than {accelerate_version}."
        )(test_case)

    return decorator


def deprecate_after_peft_backend(test_case):
    """
    Decorator marking a test that will be skipped after PEFT backend

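And a sketch of how the version guard is meant to be applied (the version string is only an example, chosen to require a release newer than it):

@require_accelerate_version_greater("0.27.0")
def test_device_map_needs_recent_accelerate(self):
    ...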