enable_model_cpu_offload (#2285)
* enable_model_offload PoC: It's surprisingly more involved than expected, see comments in the PR.
* Rename final_offload_hook
* Invoke the vae forward hook manually.
* Completely remove decoder.
* Style
* apply_forward_hook decorator
* Rename method.
* Style
* Copy enable_model_cpu_offload
* Fix copies.
* Remove comment.
* Fix copies
* Missing import
* Fix doc-builder style.
* Merge main and fix again.
* Add docs
* Fix docs.
* Add a couple of tests.
* style
@@ -133,6 +133,7 @@ images = pipe([prompt] * 32).images
You may see a small performance boost in VAE decode on multi-image batches. There should be no performance impact on single-image batches.
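As a minimal sketch, assuming this paragraph refers to sliced VAE decoding (the feature backed by the `use_slicing` flag this PR touches in `AutoencoderKL`), enabling it looks like this; the prompt is just a placeholder:

```Python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
).to("cuda")

# Decode latents one image at a time instead of all at once.
pipe.enable_vae_slicing()
images = pipe(["a photo of an astronaut riding a horse on mars"] * 32).images
```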

<a name="sequential_offloading"></a>
## Offloading to CPU with accelerate for memory savings

For additional memory savings, you can offload the weights to CPU and only load them to GPU when performing the forward pass.

@@ -156,7 +157,13 @@ image = pipe(prompt).images[0]

And you can get the memory consumption down to < 3 GB.

It is also possible to chain it with attention slicing for minimal memory consumption (< 2GB).

Note that this method works at the submodule level, not on whole models. This is the best way to minimize memory consumption, but inference is much slower due to the iterative nature of the process. The UNet component of the pipeline runs several times (as many as `num_inference_steps`); each time, the different submodules of the UNet are sequentially onloaded and then offloaded as they are needed, so the number of memory transfers is large.

<Tip>
Consider using <a href="#model_offloading">model offloading</a> as another point in the optimization space: it will be much faster, but memory savings won't be as large.
</Tip>

It is also possible to chain offloading with attention slicing for minimal memory consumption (< 2GB).

```Python
import torch
@@ -177,6 +184,55 @@ image = pipe(prompt).images[0]
**Note**: When using `enable_sequential_cpu_offload()`, it is important to **not** move the pipeline to CUDA beforehand or else the gain in memory consumption will only be minimal. See [this issue](https://github.com/huggingface/diffusers/issues/1934) for more information.
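A minimal sketch of the intended call order (the model ID and prompt are only placeholders): load the pipeline on CPU, enable sequential offloading, and let accelerate handle device placement.

```Python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
# Do NOT call pipe.to("cuda") here; accelerate manages device placement per submodule.
pipe.enable_sequential_cpu_offload()

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
```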

<a name="model_offloading"></a>
## Model offloading for fast inference and memory savings

[Sequential CPU offloading](#sequential_offloading), as discussed in the previous section, preserves a lot of memory but makes inference slower, because submodules are moved to GPU as needed, and immediately returned to CPU when a new module runs.

Full-model offloading is an alternative that moves whole models to the GPU, instead of handling each model's constituent _modules_. This results in a negligible impact on inference time (compared with moving the pipeline to `cuda`), while still providing some memory savings.

In this scenario, only one of the main components of the pipeline (typically: text encoder, unet and vae)
will be on the GPU while the others wait on the CPU. Components like the UNet that run for multiple iterations will stay on GPU until they are no longer needed.

This feature can be enabled by invoking `enable_model_cpu_offload()` on the pipeline, as shown below.

```Python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_model_cpu_offload()
image = pipe(prompt).images[0]
```

This is also compatible with attention slicing for additional memory savings.

```Python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_model_cpu_offload()
pipe.enable_attention_slicing(1)

image = pipe(prompt).images[0]
```

<Tip>
This feature requires `accelerate` version 0.17.0 or higher.
</Tip>

## Using Channels Last memory format

Channels last memory format is an alternative way of ordering NCHW tensors in memory while preserving the dimension ordering. Channels last tensors are ordered in such a way that channels become the densest dimension (aka storing images pixel-per-pixel). Since not all operators currently support the channels last format, it may result in worse performance, so it's better to try it and see if it works for your model.
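A minimal sketch of opting in, assuming a standard Stable Diffusion pipeline (the UNet is usually the component worth converting):

```Python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
).to("cuda")

# Convert the UNet's weights to the channels last memory layout;
# operators that support it will pick up the new format automatically.
pipe.unet.to(memory_format=torch.channels_last)

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
```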

@@ -18,7 +18,7 @@ import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils import BaseOutput, apply_forward_hook
from .modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder

@@ -109,6 +109,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
        self.use_slicing = False

    @apply_forward_hook
    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        h = self.encoder(x)
        moments = self.quant_conv(h)

@@ -144,6 +145,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
        """
        self.use_slicing = False

    @apply_forward_hook
    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]

@@ -19,7 +19,7 @@ import torch
from packaging import version
from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer

from diffusers.utils import is_accelerate_available
from diffusers.utils import is_accelerate_available, is_accelerate_version

from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel

@@ -188,6 +188,8 @@ class AltDiffusionPipeline(DiffusionPipeline):
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload

@@ -202,6 +204,30 @@ class AltDiffusionPipeline(DiffusionPipeline):
        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        if self.safety_checker is not None:
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    def _execution_device(self):
        r"""

@@ -209,7 +235,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (

@@ -651,6 +677,10 @@ class AltDiffusionPipeline(DiffusionPipeline):
        # 9. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)


@@ -21,7 +21,7 @@ import torch
from packaging import version
from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer

from diffusers.utils import is_accelerate_available
from diffusers.utils import is_accelerate_available, is_accelerate_version

from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel

@@ -210,6 +210,8 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload

@@ -224,6 +226,30 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        if self.safety_checker is not None:
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    def _execution_device(self):
        r"""

@@ -231,7 +257,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (

@@ -684,6 +710,10 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

@@ -215,7 +215,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (

@@ -21,7 +21,7 @@ import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from diffusers.utils import is_accelerate_available
from diffusers.utils import is_accelerate_available, is_accelerate_version

from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel

@@ -227,6 +227,8 @@ class CycleDiffusionPipeline(DiffusionPipeline):
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload

@@ -241,6 +243,31 @@ class CycleDiffusionPipeline(DiffusionPipeline):
        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        if self.safety_checker is not None:
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):

@@ -249,7 +276,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (

@@ -22,7 +22,14 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring
from ...utils import (
    deprecate,
    is_accelerate_available,
    is_accelerate_version,
    logging,
    randn_tensor,
    replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker

@@ -184,6 +191,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload

@@ -198,6 +207,30 @@ class StableDiffusionPipeline(DiffusionPipeline):
        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        if self.safety_checker is not None:
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    def _execution_device(self):
        r"""

@@ -205,7 +238,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (

@@ -647,6 +680,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
        # 9. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)


@@ -146,7 +146,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (

@@ -145,7 +145,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (

@@ -28,6 +28,7 @@ from ...utils import (
    PIL_INTERPOLATION,
    deprecate,
    is_accelerate_available,
    is_accelerate_version,
    logging,
    randn_tensor,
    replace_example_docstring,

@@ -214,6 +215,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload

@@ -228,6 +231,31 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        if self.safety_checker is not None:
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):

@@ -236,7 +264,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (

@@ -693,6 +721,10 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

@@ -24,7 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker

@@ -262,6 +262,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload

@@ -276,6 +278,31 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        if self.safety_checker is not None:
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):

@@ -284,7 +311,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (

@@ -845,6 +872,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

@@ -24,7 +24,14 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
from ...utils import (
    PIL_INTERPOLATION,
    deprecate,
    is_accelerate_available,
    is_accelerate_version,
    logging,
    randn_tensor,
)
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker

@@ -199,6 +206,8 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload

@@ -213,6 +222,31 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        if self.safety_checker is not None:
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):

@@ -221,7 +255,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (

@@ -666,6 +700,10 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

@@ -22,7 +22,14 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
from ...utils import (
    PIL_INTERPOLATION,
    deprecate,
    is_accelerate_available,
    is_accelerate_version,
    logging,
    randn_tensor,
)
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker

@@ -373,6 +380,10 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

@@ -384,6 +395,8 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload

@@ -398,6 +411,31 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        if self.safety_checker is not None:
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):

@@ -406,7 +444,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (

@@ -20,7 +20,7 @@ from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser

from ...pipelines import DiffusionPipeline
from ...schedulers import LMSDiscreteScheduler
from ...utils import is_accelerate_available, logging, randn_tensor
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor
from . import StableDiffusionPipelineOutput

@@ -127,6 +127,8 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload

@@ -141,6 +143,31 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        if self.safety_checker is not None:
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):

@@ -149,7 +176,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (

@@ -502,6 +529,10 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

@@ -117,7 +117,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (

@@ -252,6 +252,8 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload

@@ -274,7 +276,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (

@@ -166,6 +166,8 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload

@@ -188,7 +190,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (

@@ -139,7 +139,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (

@@ -186,7 +186,7 @@ class StableUnCLIPPipeline(DiffusionPipeline):
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (

@@ -188,7 +188,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (

@@ -192,7 +192,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (

@@ -170,7 +170,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"):
        if not hasattr(self.image_unet, "_hf_hook"):
            return self.device
        for module in self.image_unet.modules():
            if (

@@ -97,7 +97,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"):
        if not hasattr(self.image_unet, "_hf_hook"):
            return self.device
        for module in self.image_unet.modules():
            if (

@@ -121,7 +121,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"):
        if not hasattr(self.image_unet, "_hf_hook"):
            return self.device
        for module in self.image_unet.modules():
            if (

@@ -18,6 +18,7 @@ import os
from packaging import version

from .. import __version__
from .accelerate_utils import apply_forward_hook
from .constants import (
    CONFIG_NAME,
    DEPRECATED_REVISION_ARGS,

@@ -44,6 +45,7 @@ from .import_utils import (
    DummyObject,
    OptionalDependencyNotAvailable,
    is_accelerate_available,
    is_accelerate_version,
    is_flax_available,
    is_inflect_available,
    is_k_diffusion_available,

src/diffusers/utils/accelerate_utils.py (new file, 46 lines)
@@ -0,0 +1,46 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Accelerate utilities: Utilities related to accelerate
"""

from packaging import version

from .import_utils import is_accelerate_available


if is_accelerate_available():
    import accelerate


def apply_forward_hook(method):
    """
    Decorator that applies a registered CpuOffload hook to an arbitrary function rather than `forward`. This is useful
    for cases where a PyTorch module provides functions other than `forward` that should trigger a move to the
    appropriate acceleration device. This is the case for `encode` and `decode` in [`AutoencoderKL`].

    This decorator looks inside the internal `_hf_hook` property to find a registered offload hook.

    :param method: The method to decorate. This method should be a method of a PyTorch module.
    """
    accelerate_version = version.parse(accelerate.__version__).base_version
    if version.parse(accelerate_version) < version.parse("0.17.0"):
        return method

    def wrapper(self, *args, **kwargs):
        if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"):
            self._hf_hook.pre_forward(self)
        return method(self, *args, **kwargs)

    return wrapper
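To make the decorator's effect concrete, here is a small hedged sketch (the `TinyCodec` module and its `decode` method are hypothetical and not part of the PR); it mirrors how the PR applies the decorator to `AutoencoderKL.encode` and `AutoencoderKL.decode`:

```Python
import torch
import torch.nn as nn

from diffusers.utils import apply_forward_hook  # exported via the utils/__init__ change above


class TinyCodec(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    @apply_forward_hook
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        # If an accelerate CpuOffload hook is attached to this module (self._hf_hook),
        # the decorator runs its pre_forward first, so the weights are moved to the
        # execution device even though this is not the module's forward() method.
        return self.proj(z)
```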

@@ -476,6 +476,20 @@ def is_transformers_version(operation: str, version: str):
    return compare_versions(parse(_transformers_version), operation, version)


def is_accelerate_version(operation: str, version: str):
    """
    Compares the current Accelerate version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _accelerate_available:
        return False
    return compare_versions(parse(_accelerate_version), operation, version)
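A short hedged usage sketch, mirroring how the pipelines above gate the new offloading path on this helper:

```Python
from diffusers.utils import is_accelerate_available, is_accelerate_version

# True only if accelerate is installed and is at least 0.17.0.dev0, the minimum
# version that provides cpu_offload_with_hook.
supports_model_offload = is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0")
```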


def is_k_diffusion_version(operation: str, version: str):
    """
    Args:

@@ -789,6 +789,59 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
        # make sure that less than 2.8 GB is allocated
        assert mem_bytes < 2.8 * 10**9

    def test_stable_diffusion_pipeline_with_model_offloading(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        inputs = self.get_inputs(torch_device, dtype=torch.float16)

        # Normal inference

        pipe = StableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            torch_dtype=torch.float16,
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        outputs = pipe(**inputs)
        mem_bytes = torch.cuda.max_memory_allocated()

        # With model offloading

        # Reload but don't move to cuda
        pipe = StableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            torch_dtype=torch.float16,
        )

        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        pipe.enable_model_cpu_offload()
        pipe.set_progress_bar_config(disable=None)
        outputs_offloaded = pipe(**inputs)
        mem_bytes_offloaded = torch.cuda.max_memory_allocated()

        assert np.abs(outputs.images - outputs_offloaded.images).max() < 1e-3
        assert mem_bytes_offloaded < mem_bytes
        assert mem_bytes_offloaded < 3.5 * 10**9
        for module in pipe.text_encoder, pipe.unet, pipe.vae, pipe.safety_checker:
            assert module.device == torch.device("cpu")

        # With attention slicing
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        pipe.enable_attention_slicing()
        _ = pipe(**inputs)
        mem_bytes_slicing = torch.cuda.max_memory_allocated()

        assert mem_bytes_slicing < mem_bytes_offloaded
        assert mem_bytes_slicing < 3 * 10**9


@nightly
@require_torch_gpu

@@ -342,6 +342,47 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
        # make sure that less than 2.2 GB is allocated
        assert mem_bytes < 2.2 * 10**9

    def test_stable_diffusion_pipeline_with_model_offloading(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        inputs = self.get_inputs(torch_device, dtype=torch.float16)

        # Normal inference

        pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            safety_checker=None,
            torch_dtype=torch.float16,
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe(**inputs)
        mem_bytes = torch.cuda.max_memory_allocated()

        # With model offloading

        # Reload but don't move to cuda
        pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            safety_checker=None,
            torch_dtype=torch.float16,
        )

        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        pipe.enable_model_cpu_offload()
        pipe.set_progress_bar_config(disable=None)
        _ = pipe(**inputs)
        mem_bytes_offloaded = torch.cuda.max_memory_allocated()

        assert mem_bytes_offloaded < mem_bytes
        for module in pipe.text_encoder, pipe.unet, pipe.vae:
            assert module.device == torch.device("cpu")

    def test_stable_diffusion_img2img_pipeline_multiple_of_8(self):
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"

@@ -393,6 +393,57 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase):
        # make sure that less than 2.8 GB is allocated
        assert mem_bytes < 2.8 * 10**9

    def test_stable_diffusion_pipeline_with_model_offloading(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        inputs = self.get_inputs(torch_device, dtype=torch.float16)

        # Normal inference

        pipe = StableDiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-base",
            torch_dtype=torch.float16,
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        outputs = pipe(**inputs)
        mem_bytes = torch.cuda.max_memory_allocated()

        # With model offloading

        # Reload but don't move to cuda
        pipe = StableDiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-base",
            torch_dtype=torch.float16,
        )

        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        pipe.enable_model_cpu_offload()
        pipe.set_progress_bar_config(disable=None)
        outputs_offloaded = pipe(**inputs)
        mem_bytes_offloaded = torch.cuda.max_memory_allocated()

        assert np.abs(outputs.images - outputs_offloaded.images).max() < 1e-3
        assert mem_bytes_offloaded < mem_bytes
        assert mem_bytes_offloaded < 3 * 10**9
        for module in pipe.text_encoder, pipe.unet, pipe.vae:
            assert module.device == torch.device("cpu")

        # With attention slicing
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        pipe.enable_attention_slicing()
        _ = pipe(**inputs)
        mem_bytes_slicing = torch.cuda.max_memory_allocated()
        assert mem_bytes_slicing < mem_bytes_offloaded


@nightly
@require_torch_gpu