mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Core] add "balanced" device_map support to pipelines (#6857)
* get device <-> component mapping when using multiple gpus. * condition the device_map bits. * relax condition * device_map progress. * device_map enhancement * some cleaning up and debugging * Apply suggestions from code review Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * incorporate suggestions from PR. * remove multi-gpu condition for now. * guard check the component -> device mapping * fix: device_memory variable * dispatching transformers model to have force_hooks=True * better guarding for transformers device_map * introduce support balanced_low_memory and balanced_ultra_low_memory. * remove device_map patch. * fix: intermediate variable scoping. * fix: condition in cpu offload. * fix: flax class restrictions. * remove modifications from cpu_offload and model_offload * incorporate changes. * add a simple forward pass test * add: torch_device in get_inputs() * add: tests * remove print * safe-guard to(), model offloading and cpu offloading when balanced is used as a device_map. * style * remove . * safeguard device_map with more checks and remove invalid device_mapping strategues. * make a class attribute and adjust tests accordingly. * fix device_map check * fix test * adjust comment * fix: device_map attribute * fix: dispatching. * max_memory test for pipeline * version guard the tests * fix guard. * address review feedback. * reset_device_map method. * add: test for reset_hf_device_map * fix a couple things. * add reset_device_map() in the error message. * add tests for checking reset_device_map doesn't have unintended consequences. * fix reset_device_map and offloading tests. * create _get_final_device_map utility. * hf_device_map -> _hf_device_map * add documentation * add notes suggested by Marc. * styling. * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * move updates within gpu condition. * other docs related things * note on ignore a device not specified in . * provide a suggestion if device mapping errors out. * fix: typo. * _hf_device_map -> hf_device_map * Empty-Commit * add: example hf_device_map. --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
@@ -52,6 +52,79 @@ To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](h
|
||||
|
||||
</Tip>
|
||||
|
||||
### Device placement
|
||||
|
||||
> [!WARNING]
|
||||
> This feature is experimental and its APIs might change in the future.
|
||||
|
||||
With Accelerate, you can use the `device_map` to determine how to distribute the models of a pipeline across multiple devices. This is useful in situations where you have more than one GPU.
|
||||
|
||||
For example, if you have two 8GB GPUs, then using [`~DiffusionPipeline.enable_model_cpu_offload`] may not work so well because:
|
||||
|
||||
* it only works on a single GPU
|
||||
* a single model might not fit on a single GPU ([`~DiffusionPipeline.enable_sequential_cpu_offload`] might work but it will be extremely slow and it is also limited to a single GPU)
|
||||
|
||||
To make use of both GPUs, you can use the "balanced" device placement strategy which splits the models across all available GPUs.
|
||||
|
||||
> [!TIP]
|
||||
> Only the "balanced" strategy is supported at the moment, and we plan to support additional mapping strategies in the future.
|
||||
|
||||
```diff
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
- "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True,
|
||||
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True, device_map="balanced"
|
||||
)
|
||||
image = pipeline("a dog").images[0]
|
||||
image
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> Currently, we support only "balanced" `device_map`. We plan to support more device mapping strategies in future.
|
||||
|
||||
You can also pass a dictionary to enforce the maximum GPU memory that can be used on each device:
|
||||
|
||||
```diff
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
max_memory = {0:"1GB", 1:"1GB"}
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
device_map="balanced",
|
||||
+ max_memory=max_memory
|
||||
)
|
||||
image = pipeline("a dog").images[0]
|
||||
image
|
||||
```
|
||||
|
||||
If a device is not present in `max_memory`, then it will be completely ignored and will not participate in the device placement.
|
||||
|
||||
By default, Diffusers uses the maximum memory of all devices. If the models don't fit on the GPUs, they are offloaded to the CPU. If the CPU doesn't have enough memory, then you might see an error. In that case, you could defer to using [`~DiffusionPipeline.enable_sequential_cpu_offload`] and [`~DiffusionPipeline.enable_model_cpu_offload`].
|
||||
|
||||
Call [`~DiffusionPipeline.reset_device_map`] to reset the `device_map` of a pipeline. This is also necessary if you want to use methods like `to()`, [`~DiffusionPipeline.enable_sequential_cpu_offload`], and [`~DiffusionPipeline.enable_model_cpu_offload`] on a pipeline that was device-mapped.
|
||||
|
||||
```py
|
||||
pipeline.reset_device_map()
|
||||
```
|
||||
|
||||
Once a pipeline has been device-mapped, you can also access its device map via `hf_device_map`:
|
||||
|
||||
```py
|
||||
print(pipeline.hf_device_map)
|
||||
```
|
||||
|
||||
An example device map would look like so:
|
||||
|
||||
|
||||
```bash
|
||||
{'unet': 1, 'vae': 1, 'safety_checker': 0, 'text_encoder': 0}
|
||||
```
|
||||
|
||||
## PyTorch Distributed
|
||||
|
||||
PyTorch supports [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) which enables data parallelism.
|
||||
|
||||
@@ -699,6 +699,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
offload_folder=offload_folder,
|
||||
offload_state_dict=offload_state_dict,
|
||||
dtype=torch_dtype,
|
||||
force_hooks=True,
|
||||
)
|
||||
except AttributeError as e:
|
||||
# When using accelerate loading, we do not have the ability to load the state
|
||||
|
||||
@@ -22,15 +22,19 @@ from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from huggingface_hub import (
|
||||
model_info,
|
||||
)
|
||||
from huggingface_hub import model_info
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from packaging import version
|
||||
|
||||
from .. import __version__
|
||||
from ..utils import (
|
||||
FLAX_WEIGHTS_NAME,
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME,
|
||||
ONNX_WEIGHTS_NAME,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
get_class_from_dynamic_module,
|
||||
is_accelerate_available,
|
||||
is_peft_available,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
@@ -44,9 +48,12 @@ if is_transformers_available():
|
||||
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
|
||||
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
|
||||
if is_accelerate_available():
|
||||
import accelerate
|
||||
from accelerate import dispatch_model
|
||||
from accelerate.hooks import remove_hook_from_module
|
||||
from accelerate.utils import compute_module_sizes, get_max_memory
|
||||
|
||||
|
||||
INDEX_FILE = "diffusion_pytorch_model.bin"
|
||||
@@ -376,6 +383,207 @@ def _get_pipeline_class(
|
||||
return pipeline_cls
|
||||
|
||||
|
||||
def _load_empty_model(
|
||||
library_name: str,
|
||||
class_name: str,
|
||||
importable_classes: List[Any],
|
||||
pipelines: Any,
|
||||
is_pipeline_module: bool,
|
||||
name: str,
|
||||
torch_dtype: Union[str, torch.dtype],
|
||||
cached_folder: Union[str, os.PathLike],
|
||||
**kwargs,
|
||||
):
|
||||
# retrieve class objects.
|
||||
class_obj, _ = get_class_obj_and_candidates(
|
||||
library_name,
|
||||
class_name,
|
||||
importable_classes,
|
||||
pipelines,
|
||||
is_pipeline_module,
|
||||
component_name=name,
|
||||
cache_dir=cached_folder,
|
||||
)
|
||||
|
||||
if is_transformers_available():
|
||||
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
||||
else:
|
||||
transformers_version = "N/A"
|
||||
|
||||
# Determine library.
|
||||
is_transformers_model = (
|
||||
is_transformers_available()
|
||||
and issubclass(class_obj, PreTrainedModel)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
|
||||
|
||||
model = None
|
||||
config_path = cached_folder
|
||||
user_agent = {
|
||||
"diffusers": __version__,
|
||||
"file_type": "model",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
if is_diffusers_model:
|
||||
# Load config and then the model on meta.
|
||||
config, unused_kwargs, commit_hash = class_obj.load_config(
|
||||
os.path.join(config_path, name),
|
||||
cache_dir=cached_folder,
|
||||
return_unused_kwargs=True,
|
||||
return_commit_hash=True,
|
||||
force_download=kwargs.pop("force_download", False),
|
||||
resume_download=kwargs.pop("resume_download", False),
|
||||
proxies=kwargs.pop("proxies", None),
|
||||
local_files_only=kwargs.pop("local_files_only", False),
|
||||
token=kwargs.pop("token", None),
|
||||
revision=kwargs.pop("revision", None),
|
||||
subfolder=kwargs.pop("subfolder", None),
|
||||
user_agent=user_agent,
|
||||
)
|
||||
with accelerate.init_empty_weights():
|
||||
model = class_obj.from_config(config, **unused_kwargs)
|
||||
elif is_transformers_model:
|
||||
config_class = getattr(class_obj, "config_class", None)
|
||||
if config_class is None:
|
||||
raise ValueError("`config_class` cannot be None. Please double-check the model.")
|
||||
|
||||
config = config_class.from_pretrained(
|
||||
cached_folder,
|
||||
subfolder=name,
|
||||
force_download=kwargs.pop("force_download", False),
|
||||
resume_download=kwargs.pop("resume_download", False),
|
||||
proxies=kwargs.pop("proxies", None),
|
||||
local_files_only=kwargs.pop("local_files_only", False),
|
||||
token=kwargs.pop("token", None),
|
||||
revision=kwargs.pop("revision", None),
|
||||
user_agent=user_agent,
|
||||
)
|
||||
with accelerate.init_empty_weights():
|
||||
model = class_obj(config)
|
||||
|
||||
if model is not None:
|
||||
model = model.to(dtype=torch_dtype)
|
||||
return model
|
||||
|
||||
|
||||
def _assign_components_to_devices(
|
||||
module_sizes: Dict[str, float], device_memory: Dict[str, float], device_mapping_strategy: str = "balanced"
|
||||
):
|
||||
device_ids = list(device_memory.keys())
|
||||
device_cycle = device_ids + device_ids[::-1]
|
||||
device_memory = device_memory.copy()
|
||||
|
||||
device_id_component_mapping = {}
|
||||
current_device_index = 0
|
||||
for component in module_sizes:
|
||||
device_id = device_cycle[current_device_index % len(device_cycle)]
|
||||
component_memory = module_sizes[component]
|
||||
curr_device_memory = device_memory[device_id]
|
||||
|
||||
# If the GPU doesn't fit the current component offload to the CPU.
|
||||
if component_memory > curr_device_memory:
|
||||
device_id_component_mapping["cpu"] = [component]
|
||||
else:
|
||||
if device_id not in device_id_component_mapping:
|
||||
device_id_component_mapping[device_id] = [component]
|
||||
else:
|
||||
device_id_component_mapping[device_id].append(component)
|
||||
|
||||
# Update the device memory.
|
||||
device_memory[device_id] -= component_memory
|
||||
current_device_index += 1
|
||||
|
||||
return device_id_component_mapping
|
||||
|
||||
|
||||
def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
|
||||
# To avoid circular import problem.
|
||||
from diffusers import pipelines
|
||||
|
||||
torch_dtype = kwargs.get("torch_dtype", torch.float32)
|
||||
|
||||
# Load each module in the pipeline on a meta device so that we can derive the device map.
|
||||
init_empty_modules = {}
|
||||
for name, (library_name, class_name) in init_dict.items():
|
||||
if class_name.startswith("Flax"):
|
||||
raise ValueError("Flax pipelines are not supported with `device_map`.")
|
||||
|
||||
# Define all importable classes
|
||||
is_pipeline_module = hasattr(pipelines, library_name)
|
||||
importable_classes = ALL_IMPORTABLE_CLASSES
|
||||
loaded_sub_model = None
|
||||
|
||||
# Use passed sub model or load class_name from library_name
|
||||
if name in passed_class_obj:
|
||||
# if the model is in a pipeline module, then we load it from the pipeline
|
||||
# check that passed_class_obj has correct parent class
|
||||
maybe_raise_or_warn(
|
||||
library_name,
|
||||
library,
|
||||
class_name,
|
||||
importable_classes,
|
||||
passed_class_obj,
|
||||
name,
|
||||
is_pipeline_module,
|
||||
)
|
||||
with accelerate.init_empty_weights():
|
||||
loaded_sub_model = passed_class_obj[name]
|
||||
|
||||
else:
|
||||
loaded_sub_model = _load_empty_model(
|
||||
library_name=library_name,
|
||||
class_name=class_name,
|
||||
importable_classes=importable_classes,
|
||||
pipelines=pipelines,
|
||||
is_pipeline_module=is_pipeline_module,
|
||||
pipeline_class=pipeline_class,
|
||||
name=name,
|
||||
torch_dtype=torch_dtype,
|
||||
cached_folder=kwargs.get("cached_folder", None),
|
||||
force_download=kwargs.get("force_download", None),
|
||||
resume_download=kwargs.get("resume_download", None),
|
||||
proxies=kwargs.get("proxies", None),
|
||||
local_files_only=kwargs.get("local_files_only", None),
|
||||
token=kwargs.get("token", None),
|
||||
revision=kwargs.get("revision", None),
|
||||
)
|
||||
|
||||
if loaded_sub_model is not None:
|
||||
init_empty_modules[name] = loaded_sub_model
|
||||
|
||||
# determine device map
|
||||
# Obtain a sorted dictionary for mapping the model-level components
|
||||
# to their sizes.
|
||||
module_sizes = {
|
||||
module_name: compute_module_sizes(module, dtype=torch_dtype)[""]
|
||||
for module_name, module in init_empty_modules.items()
|
||||
if isinstance(module, torch.nn.Module)
|
||||
}
|
||||
module_sizes = dict(sorted(module_sizes.items(), key=lambda item: item[1], reverse=True))
|
||||
|
||||
# Obtain maximum memory available per device (GPUs only).
|
||||
max_memory = get_max_memory(max_memory)
|
||||
max_memory = dict(sorted(max_memory.items(), key=lambda item: item[1], reverse=True))
|
||||
max_memory = {k: v for k, v in max_memory.items() if k != "cpu"}
|
||||
|
||||
# Obtain a dictionary mapping the model-level components to the available
|
||||
# devices based on the maximum memory and the model sizes.
|
||||
device_id_component_mapping = _assign_components_to_devices(
|
||||
module_sizes, max_memory, device_mapping_strategy=device_map
|
||||
)
|
||||
|
||||
# Obtain the final device map, e.g., `{"unet": 0, "text_encoder": 1, "vae": 1, ...}`
|
||||
final_device_map = {}
|
||||
for device_id, components in device_id_component_mapping.items():
|
||||
for component in components:
|
||||
final_device_map[component] = device_id
|
||||
|
||||
return final_device_map
|
||||
|
||||
|
||||
def load_sub_model(
|
||||
library_name: str,
|
||||
class_name: str,
|
||||
@@ -493,6 +701,22 @@ def load_sub_model(
|
||||
# else load from the root directory
|
||||
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
|
||||
|
||||
if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict):
|
||||
# remove hooks
|
||||
remove_hook_from_module(loaded_sub_model, recurse=True)
|
||||
needs_offloading_to_cpu = device_map[""] == "cpu"
|
||||
|
||||
if needs_offloading_to_cpu:
|
||||
dispatch_model(
|
||||
loaded_sub_model,
|
||||
state_dict=loaded_sub_model.state_dict(),
|
||||
device_map=device_map,
|
||||
force_hooks=True,
|
||||
main_device=0,
|
||||
)
|
||||
else:
|
||||
dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)
|
||||
|
||||
return loaded_sub_model
|
||||
|
||||
|
||||
|
||||
@@ -73,6 +73,7 @@ from .pipeline_loading_utils import (
|
||||
LOADABLE_CLASSES,
|
||||
_fetch_class_library_tuple,
|
||||
_get_custom_pipeline_class,
|
||||
_get_final_device_map,
|
||||
_get_pipeline_class,
|
||||
_unwrap_model,
|
||||
is_safetensors_compatible,
|
||||
@@ -91,6 +92,8 @@ LIBRARIES = []
|
||||
for library in LOADABLE_CLASSES:
|
||||
LIBRARIES.append(library)
|
||||
|
||||
SUPPORTED_DEVICE_MAP = ["balanced"]
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -141,6 +144,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
config_name = "model_index.json"
|
||||
model_cpu_offload_seq = None
|
||||
hf_device_map = None
|
||||
_optional_components = []
|
||||
_exclude_from_cpu_offload = []
|
||||
_load_connected_pipes = False
|
||||
@@ -389,6 +393,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
|
||||
)
|
||||
|
||||
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
||||
if is_pipeline_device_mapped:
|
||||
raise ValueError(
|
||||
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`."
|
||||
)
|
||||
|
||||
# Display a warning in this case (the operation succeeds but the benefits are lost)
|
||||
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
|
||||
if pipeline_is_offloaded and device and torch.device(device).type == "cuda":
|
||||
@@ -642,18 +652,35 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `device_map=None`."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
if device_map is not None and not is_accelerate_available():
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
"Using `device_map` requires the `accelerate` library. Please install it using: `pip install accelerate`."
|
||||
)
|
||||
|
||||
if device_map is not None and not isinstance(device_map, str):
|
||||
raise ValueError("`device_map` must be a string.")
|
||||
|
||||
if device_map is not None and device_map not in SUPPORTED_DEVICE_MAP:
|
||||
raise NotImplementedError(
|
||||
f"{device_map} not supported. Supported strategies are: {', '.join(SUPPORTED_DEVICE_MAP)}"
|
||||
)
|
||||
|
||||
if device_map is not None and device_map in SUPPORTED_DEVICE_MAP:
|
||||
if is_accelerate_version("<", "0.28.0"):
|
||||
raise NotImplementedError("Device placement requires `accelerate` version `0.28.0` or later.")
|
||||
|
||||
if low_cpu_mem_usage is False and device_map is not None:
|
||||
raise ValueError(
|
||||
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
|
||||
@@ -729,6 +756,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
revision=custom_revision,
|
||||
)
|
||||
|
||||
if device_map is not None and pipeline_class._load_connected_pipes:
|
||||
raise NotImplementedError("`device_map` is not yet supported for connected pipelines.")
|
||||
|
||||
# DEPRECATED: To be removed in 1.0.0
|
||||
if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
|
||||
version.parse(config_dict["_diffusers_version"]).base_version
|
||||
@@ -795,17 +825,45 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
# import it here to avoid circular import
|
||||
from diffusers import pipelines
|
||||
|
||||
# 6. Load each module in the pipeline
|
||||
# 6. device map delegation
|
||||
final_device_map = None
|
||||
if device_map is not None:
|
||||
final_device_map = _get_final_device_map(
|
||||
device_map=device_map,
|
||||
pipeline_class=pipeline_class,
|
||||
passed_class_obj=passed_class_obj,
|
||||
init_dict=init_dict,
|
||||
library=library,
|
||||
max_memory=max_memory,
|
||||
torch_dtype=torch_dtype,
|
||||
cached_folder=cached_folder,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
# 7. Load each module in the pipeline
|
||||
current_device_map = None
|
||||
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
|
||||
# 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
||||
if final_device_map is not None and len(final_device_map) > 0:
|
||||
component_device = final_device_map.get(name, None)
|
||||
if component_device is not None:
|
||||
current_device_map = {"": component_device}
|
||||
else:
|
||||
current_device_map = None
|
||||
|
||||
# 7.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
||||
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
|
||||
|
||||
# 6.2 Define all importable classes
|
||||
# 7.2 Define all importable classes
|
||||
is_pipeline_module = hasattr(pipelines, library_name)
|
||||
importable_classes = ALL_IMPORTABLE_CLASSES
|
||||
loaded_sub_model = None
|
||||
|
||||
# 6.3 Use passed sub model or load class_name from library_name
|
||||
# 7.3 Use passed sub model or load class_name from library_name
|
||||
if name in passed_class_obj:
|
||||
# if the model is in a pipeline module, then we load it from the pipeline
|
||||
# check that passed_class_obj has correct parent class
|
||||
@@ -826,7 +884,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
torch_dtype=torch_dtype,
|
||||
provider=provider,
|
||||
sess_options=sess_options,
|
||||
device_map=device_map,
|
||||
device_map=current_device_map,
|
||||
max_memory=max_memory,
|
||||
offload_folder=offload_folder,
|
||||
offload_state_dict=offload_state_dict,
|
||||
@@ -893,7 +951,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
{"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
|
||||
)
|
||||
|
||||
# 7. Potentially add passed objects if expected
|
||||
# 8. Potentially add passed objects if expected
|
||||
missing_modules = set(expected_modules) - set(init_kwargs.keys())
|
||||
passed_modules = list(passed_class_obj.keys())
|
||||
optional_modules = pipeline_class._optional_components
|
||||
@@ -906,11 +964,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
|
||||
)
|
||||
|
||||
# 8. Instantiate the pipeline
|
||||
# 10. Instantiate the pipeline
|
||||
model = pipeline_class(**init_kwargs)
|
||||
|
||||
# 9. Save where the model was instantiated from
|
||||
# 11. Save where the model was instantiated from
|
||||
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
||||
if device_map is not None:
|
||||
setattr(model, "hf_device_map", final_device_map)
|
||||
return model
|
||||
|
||||
@property
|
||||
@@ -963,6 +1023,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
|
||||
default to "cuda".
|
||||
"""
|
||||
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
||||
if is_pipeline_device_mapped:
|
||||
raise ValueError(
|
||||
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
|
||||
)
|
||||
|
||||
if self.model_cpu_offload_seq is None:
|
||||
raise ValueError(
|
||||
"Model CPU offload cannot be enabled because no `model_cpu_offload_seq` class attribute is set."
|
||||
@@ -1056,6 +1122,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
|
||||
self.remove_all_hooks()
|
||||
|
||||
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
||||
if is_pipeline_device_mapped:
|
||||
raise ValueError(
|
||||
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
|
||||
)
|
||||
|
||||
torch_device = torch.device(device)
|
||||
device_index = torch_device.index
|
||||
|
||||
@@ -1090,6 +1162,19 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
offload_buffers = len(model._parameters) > 0
|
||||
cpu_offload(model, device, offload_buffers=offload_buffers)
|
||||
|
||||
def reset_device_map(self):
|
||||
r"""
|
||||
Resets the device maps (if any) to None.
|
||||
"""
|
||||
if self.hf_device_map is None:
|
||||
return
|
||||
else:
|
||||
self.remove_all_hooks()
|
||||
for name, component in self.components.items():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
component.to("cpu")
|
||||
self.hf_device_map = None
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
|
||||
|
||||
@@ -255,6 +255,20 @@ def require_torch_accelerator(test_case):
|
||||
)
|
||||
|
||||
|
||||
def require_torch_multi_gpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without
|
||||
multiple GPUs. To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests
|
||||
-k "multi_gpu"
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
|
||||
import torch
|
||||
|
||||
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
|
||||
|
||||
|
||||
def require_torch_accelerator_with_fp16(test_case):
|
||||
"""Decorator marking a test that requires an accelerator with support for the FP16 data type."""
|
||||
return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
|
||||
@@ -343,6 +357,18 @@ def require_peft_version_greater(peft_version):
|
||||
return decorator
|
||||
|
||||
|
||||
def require_accelerate_version_greater(accelerate_version):
|
||||
def decorator(test_case):
|
||||
correct_accelerate_version = is_peft_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("accelerate")).base_version
|
||||
) > version.parse(accelerate_version)
|
||||
return unittest.skipUnless(
|
||||
correct_accelerate_version, f"Test requires accelerate with the version greater than {accelerate_version}."
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def deprecate_after_peft_backend(test_case):
|
||||
"""
|
||||
Decorator marking a test that will be skipped after PEFT backend
|
||||
|
||||
@@ -124,7 +124,7 @@ class LDMSuperResolutionPipelineIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
init_image = init_image.resize((64, 64), resample=PIL_INTERPOLATION["lanczos"])
|
||||
|
||||
ldm = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution", device_map="auto")
|
||||
ldm = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution")
|
||||
ldm.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
|
||||
@@ -50,9 +50,11 @@ from diffusers.utils.testing_utils import (
|
||||
load_numpy,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_accelerate_version_greater,
|
||||
require_python39_or_higher,
|
||||
require_torch_2,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
run_test_in_subprocess,
|
||||
skip_mps,
|
||||
slow,
|
||||
@@ -1442,3 +1444,121 @@ class StableDiffusionPipelineNightlyTests(unittest.TestCase):
|
||||
)
|
||||
max_diff = np.abs(expected_image - image).max()
|
||||
assert max_diff < 1e-3
|
||||
|
||||
|
||||
# (sayakpaul): This test suite was run in the DGX with two GPUs (1, 2).
|
||||
@slow
|
||||
@require_torch_multi_gpu
|
||||
@require_accelerate_version_greater("0.27.0")
|
||||
class StableDiffusionPipelineDeviceMapTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self, generator_device="cpu", seed=0):
|
||||
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "a photograph of an astronaut riding a horse",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 50,
|
||||
"guidance_scale": 7.5,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def get_pipeline_output_without_device_map(self):
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
|
||||
).to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=True)
|
||||
inputs = self.get_inputs()
|
||||
no_device_map_image = sd_pipe(**inputs).images
|
||||
|
||||
del sd_pipe
|
||||
|
||||
return no_device_map_image
|
||||
|
||||
def test_forward_pass_balanced_device_map(self):
|
||||
no_device_map_image = self.get_pipeline_output_without_device_map()
|
||||
|
||||
sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16
|
||||
)
|
||||
sd_pipe_with_device_map.set_progress_bar_config(disable=True)
|
||||
inputs = self.get_inputs()
|
||||
device_map_image = sd_pipe_with_device_map(**inputs).images
|
||||
|
||||
max_diff = np.abs(device_map_image - no_device_map_image).max()
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_components_put_in_right_devices(self):
|
||||
sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
assert len(set(sd_pipe_with_device_map.hf_device_map.values())) >= 2
|
||||
|
||||
def test_max_memory(self):
|
||||
no_device_map_image = self.get_pipeline_output_without_device_map()
|
||||
|
||||
sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
device_map="balanced",
|
||||
max_memory={0: "1GB", 1: "1GB"},
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
sd_pipe_with_device_map.set_progress_bar_config(disable=True)
|
||||
inputs = self.get_inputs()
|
||||
device_map_image = sd_pipe_with_device_map(**inputs).images
|
||||
|
||||
max_diff = np.abs(device_map_image - no_device_map_image).max()
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_reset_device_map(self):
|
||||
sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16
|
||||
)
|
||||
sd_pipe_with_device_map.reset_device_map()
|
||||
|
||||
assert sd_pipe_with_device_map.hf_device_map is None
|
||||
|
||||
for name, component in sd_pipe_with_device_map.components.items():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
assert component.device.type == "cpu"
|
||||
|
||||
def test_reset_device_map_to(self):
|
||||
sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16
|
||||
)
|
||||
sd_pipe_with_device_map.reset_device_map()
|
||||
|
||||
assert sd_pipe_with_device_map.hf_device_map is None
|
||||
|
||||
# Make sure `to()` can be used and the pipeline can be called.
|
||||
pipe = sd_pipe_with_device_map.to("cuda")
|
||||
_ = pipe("hello", num_inference_steps=2)
|
||||
|
||||
def test_reset_device_map_enable_model_cpu_offload(self):
|
||||
sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16
|
||||
)
|
||||
sd_pipe_with_device_map.reset_device_map()
|
||||
|
||||
assert sd_pipe_with_device_map.hf_device_map is None
|
||||
|
||||
# Make sure `enable_model_cpu_offload()` can be used and the pipeline can be called.
|
||||
sd_pipe_with_device_map.enable_model_cpu_offload()
|
||||
_ = sd_pipe_with_device_map("hello", num_inference_steps=2)
|
||||
|
||||
def test_reset_device_map_enable_sequential_cpu_offload(self):
|
||||
sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16
|
||||
)
|
||||
sd_pipe_with_device_map.reset_device_map()
|
||||
|
||||
assert sd_pipe_with_device_map.hf_device_map is None
|
||||
|
||||
# Make sure `enable_sequential_cpu_offload()` can be used and the pipeline can be called.
|
||||
sd_pipe_with_device_map.enable_sequential_cpu_offload()
|
||||
_ = sd_pipe_with_device_map("hello", num_inference_steps=2)
|
||||
|
||||
Reference in New Issue
Block a user