diff --git a/docs/source/en/optimization/fp16.mdx b/docs/source/en/optimization/fp16.mdx
index ca24556871..a7115c3ea5 100644
--- a/docs/source/en/optimization/fp16.mdx
+++ b/docs/source/en/optimization/fp16.mdx
@@ -133,6 +133,7 @@ images = pipe([prompt] * 32).images
 You may see a small performance boost in VAE decode on multi-image batches. There should be no performance impact on single-image batches.
 
+
 ## Offloading to CPU with accelerate for memory savings
 
 For additional memory savings, you can offload the weights to CPU and only load them to GPU when performing the forward pass.
@@ -156,7 +157,13 @@ image = pipe(prompt).images[0]
 
 And you can get the memory consumption to < 3GB.
 
-If is also possible to chain it with attention slicing for minimal memory consumption (< 2GB).
+Note that this method works at the submodule level, not on whole models. This is the best way to minimize memory consumption, but inference is much slower due to the iterative nature of the process. The UNet component of the pipeline runs several times (as many as `num_inference_steps`); each time, the different submodules of the UNet are sequentially onloaded and then offloaded as they are needed, so the number of memory transfers is large.
+
+
+Consider using model offloading as another point in the optimization space: it will be much faster, but memory savings won't be as large.
+
+
+It is also possible to chain offloading with attention slicing for minimal memory consumption (< 2GB).
 
 ```Python
 import torch
@@ -177,6 +184,55 @@ image = pipe(prompt).images[0]
 **Note**: When using `enable_sequential_cpu_offload()`, it is important to **not** move the pipeline to CUDA beforehand or else the gain in memory consumption will only be minimal. See [this issue](https://github.com/huggingface/diffusers/issues/1934) for more information.
 
+
+
+## Model offloading for fast inference and memory savings
+
+[Sequential CPU offloading](#sequential_offloading), as discussed in the previous section, preserves a lot of memory but makes inference slower, because submodules are moved to GPU as needed, and immediately returned to CPU when a new module runs.
+
+Full-model offloading is an alternative that moves whole models to the GPU, instead of handling each model's constituent _modules_. This results in a negligible impact on inference time (compared with moving the pipeline to `cuda`), while still providing some memory savings.
+
+In this scenario, only one of the main components of the pipeline (typically: text encoder, unet and vae)
+will be on the GPU while the others wait on the CPU. Components like the UNet that run for multiple iterations will stay on GPU until they are no longer needed.
+
+This feature can be enabled by invoking `enable_model_cpu_offload()` on the pipeline, as shown below.
+
+```Python
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5",
+    torch_dtype=torch.float16,
+)
+
+prompt = "a photo of an astronaut riding a horse on mars"
+pipe.enable_model_cpu_offload()
+image = pipe(prompt).images[0]
+```
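+
+You can sanity-check the savings on your own hardware by measuring peak GPU memory around a generation, which is what the test suite does for this feature. The exact numbers will vary with your GPU and pipeline configuration:
+
+```Python
+import torch
+
+# Reuses `pipe` and `prompt` from the snippet above.
+torch.cuda.empty_cache()
+torch.cuda.reset_peak_memory_stats()
+
+image = pipe(prompt).images[0]
+print(f"Peak GPU memory: {torch.cuda.max_memory_allocated() / 2**30:.2f} GB")
+```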
+
+This is also compatible with attention slicing for additional memory savings.
+
+```Python
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5",
+    torch_dtype=torch.float16,
+)
+
+prompt = "a photo of an astronaut riding a horse on mars"
+pipe.enable_model_cpu_offload()
+pipe.enable_attention_slicing(1)
+
+image = pipe(prompt).images[0]
+```
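+
+Model offloading moves each component back to the CPU as a whole once the next one runs, and the last model is offloaded when the pipeline call finishes. As a quick check (mirroring the assertions in the test suite), you can confirm that the main components end up back on the CPU after a generation:
+
+```Python
+# All major components should be back on the CPU once generation finishes.
+for module in (pipe.text_encoder, pipe.unet, pipe.vae):
+    assert module.device == torch.device("cpu")
+```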
""" if is_accelerate_available(): from accelerate import cpu_offload @@ -202,6 +204,30 @@ class AltDiffusionPipeline(DiffusionPipeline): if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + @property def _execution_device(self): r""" @@ -209,7 +235,7 @@ class AltDiffusionPipeline(DiffusionPipeline): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( @@ -651,6 +677,10 @@ class AltDiffusionPipeline(DiffusionPipeline): # 9. Run safety checker image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 828c133d67..ab37d68682 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -21,7 +21,7 @@ import torch from packaging import version from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer -from diffusers.utils import is_accelerate_available +from diffusers.utils import is_accelerate_available, is_accelerate_version from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel @@ -210,6 +210,8 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. 
""" if is_accelerate_available(): from accelerate import cpu_offload @@ -224,6 +226,30 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + @property def _execution_device(self): r""" @@ -231,7 +257,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( @@ -684,6 +710,10 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): if output_type == "pil": image = self.numpy_to_pil(image) + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py index 693166d106..45f8868fad 100644 --- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -215,7 +215,7 @@ class PaintByExamplePipeline(DiffusionPipeline): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. 
""" - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index f892fbe6db..7932bd0bd6 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -21,7 +21,7 @@ import torch from packaging import version from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer -from diffusers.utils import is_accelerate_available +from diffusers.utils import is_accelerate_available, is_accelerate_version from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel @@ -227,6 +227,8 @@ class CycleDiffusionPipeline(DiffusionPipeline): Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. """ if is_accelerate_available(): from accelerate import cpu_offload @@ -241,6 +243,31 @@ class CycleDiffusionPipeline(DiffusionPipeline): if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): @@ -249,7 +276,7 @@ class CycleDiffusionPipeline(DiffusionPipeline): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. 
""" - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index c3bfa334ca..f6e5e7d2dd 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -22,7 +22,14 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers -from ...utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring +from ...utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -184,6 +191,8 @@ class StableDiffusionPipeline(DiffusionPipeline): Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. """ if is_accelerate_available(): from accelerate import cpu_offload @@ -198,6 +207,30 @@ class StableDiffusionPipeline(DiffusionPipeline): if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + @property def _execution_device(self): r""" @@ -205,7 +238,7 @@ class StableDiffusionPipeline(DiffusionPipeline): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. 
""" - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( @@ -647,6 +680,10 @@ class StableDiffusionPipeline(DiffusionPipeline): # 9. Run safety checker image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index 0d90e0755e..879f374af7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -146,7 +146,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py index 026ab441d5..e7419114ae 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -145,7 +145,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index cffedf5456..3a17f094b1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -28,6 +28,7 @@ from ...utils import ( PIL_INTERPOLATION, deprecate, is_accelerate_available, + is_accelerate_version, logging, randn_tensor, replace_example_docstring, @@ -214,6 +215,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. 
""" if is_accelerate_available(): from accelerate import cpu_offload @@ -228,6 +231,31 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): @@ -236,7 +264,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( @@ -693,6 +721,10 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): if output_type == "pil": image = self.numpy_to_pil(image) + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index d10bded530..f66b7d3a0b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -24,7 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers -from ...utils import deprecate, is_accelerate_available, logging, randn_tensor +from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -262,6 +262,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): Offloads all models to CPU using accelerate, significantly reducing memory usage. 
         text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
         `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
+        Note that offloading happens on a submodule basis. Memory savings are higher than with
+        `enable_model_cpu_offload`, but performance is lower.
         """
         if is_accelerate_available():
             from accelerate import cpu_offload
@@ -276,6 +278,31 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         if self.safety_checker is not None:
             cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
+    def enable_model_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+        method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+        """
+        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+            from accelerate import cpu_offload_with_hook
+        else:
+            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        hook = None
+        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+        if self.safety_checker is not None:
+            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+        # We'll offload the last model manually.
+        self.final_offload_hook = hook
+
     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
     def _execution_device(self):
         r"""
@@ -284,7 +311,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
         hooks.
""" - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( @@ -845,6 +872,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): if output_type == "pil": image = self.numpy_to_pil(image) + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index bad552b710..e71d8e593e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -24,7 +24,14 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers -from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor +from ...utils import ( + PIL_INTERPOLATION, + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, +) from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -199,6 +206,8 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. """ if is_accelerate_available(): from accelerate import cpu_offload @@ -213,6 +222,31 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
+ """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): @@ -221,7 +255,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( @@ -666,6 +700,10 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): if output_type == "pil": image = self.numpy_to_pil(image) + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index d47e69cc44..a7b73981e7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -22,7 +22,14 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers -from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor +from ...utils import ( + PIL_INTERPOLATION, + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, +) from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -373,6 +380,10 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): if output_type == "pil": image = self.numpy_to_pil(image) + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + if not return_dict: return (image, has_nsfw_concept) @@ -384,6 +395,8 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. 
""" if is_accelerate_available(): from accelerate import cpu_offload @@ -398,6 +411,31 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): @@ -406,7 +444,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index c3055d8a16..7cf3c7bf58 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -20,7 +20,7 @@ from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser from ...pipelines import DiffusionPipeline from ...schedulers import LMSDiscreteScheduler -from ...utils import is_accelerate_available, logging, randn_tensor +from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor from . import StableDiffusionPipelineOutput @@ -127,6 +127,8 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. 
""" if is_accelerate_available(): from accelerate import cpu_offload @@ -141,6 +143,31 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): @@ -149,7 +176,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( @@ -502,6 +529,10 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): if output_type == "pil": image = self.numpy_to_pil(image) + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py index 361304116a..624d0e6258 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py @@ -117,7 +117,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. 
""" - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index f0f09b8739..af0a7e2a4b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -252,6 +252,8 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. """ if is_accelerate_available(): from accelerate import cpu_offload @@ -274,7 +276,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py index e49c48c3c7..0c1811fc64 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py @@ -166,6 +166,8 @@ class StableDiffusionSAGPipeline(DiffusionPipeline): Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. """ if is_accelerate_available(): from accelerate import cpu_offload @@ -188,7 +190,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 5bc268f8c3..b6c0b591ec 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -139,7 +139,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. 
""" - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index 59a4902ec5..d4eabd70ac 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -186,7 +186,7 @@ class StableUnCLIPPipeline(DiffusionPipeline): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index d4f85ce7c6..8fdc47a786 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -188,7 +188,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py index cb555cd0a3..3d0ddce715 100644 --- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -192,7 +192,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py index e3d0db540b..849e25675d 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -170,7 +170,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. 
""" - if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"): + if not hasattr(self.image_unet, "_hf_hook"): return self.device for module in self.image_unet.modules(): if ( diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py index 815488d7bc..d676b87b96 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -97,7 +97,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ - if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"): + if not hasattr(self.image_unet, "_hf_hook"): return self.device for module in self.image_unet.modules(): if ( diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py index 87e3f39822..9e983938fd 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -121,7 +121,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ - if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"): + if not hasattr(self.image_unet, "_hf_hook"): return self.device for module in self.image_unet.modules(): if ( diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 822fa9cf6f..429c4e39de 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -18,6 +18,7 @@ import os from packaging import version from .. import __version__ +from .accelerate_utils import apply_forward_hook from .constants import ( CONFIG_NAME, DEPRECATED_REVISION_ARGS, @@ -44,6 +45,7 @@ from .import_utils import ( DummyObject, OptionalDependencyNotAvailable, is_accelerate_available, + is_accelerate_version, is_flax_available, is_inflect_available, is_k_diffusion_available, diff --git a/src/diffusers/utils/accelerate_utils.py b/src/diffusers/utils/accelerate_utils.py new file mode 100644 index 0000000000..a92d475476 --- /dev/null +++ b/src/diffusers/utils/accelerate_utils.py @@ -0,0 +1,46 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Accelerate utilities: Utilities related to accelerate +""" + +from packaging import version + +from .import_utils import is_accelerate_available + + +if is_accelerate_available(): + import accelerate + + +def apply_forward_hook(method): + """ + Decorator that applies a registered CpuOffload hook to an arbitrary function rather than `forward`. This is useful + for cases where a PyTorch module provides functions other than `forward` that should trigger a move to the + appropriate acceleration device. This is the case for `encode` and `decode` in [`AutoencoderKL`]. + + This decorator looks inside the internal `_hf_hook` property to find a registered offload hook. + + :param method: The method to decorate. This method should be a method of a PyTorch module. + """ + accelerate_version = version.parse(accelerate.__version__).base_version + if version.parse(accelerate_version) < version.parse("0.17.0"): + return method + + def wrapper(self, *args, **kwargs): + if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"): + self._hf_hook.pre_forward(self) + return method(self, *args, **kwargs) + + return wrapper diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 7afa1af2e7..925e431f55 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -476,6 +476,20 @@ def is_transformers_version(operation: str, version: str): return compare_versions(parse(_transformers_version), operation, version) +def is_accelerate_version(operation: str, version: str): + """ + Args: + Compares the current Accelerate version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _accelerate_available: + return False + return compare_versions(parse(_accelerate_version), operation, version) + + def is_k_diffusion_version(operation: str, version: str): """ Args: diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 02774d69dc..75e2a44f01 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -789,6 +789,59 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase): # make sure that less than 2.8 GB is allocated assert mem_bytes < 2.8 * 10**9 + def test_stable_diffusion_pipeline_with_model_offloading(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + inputs = self.get_inputs(torch_device, dtype=torch.float16) + + # Normal inference + + pipe = StableDiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + torch_dtype=torch.float16, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + outputs = pipe(**inputs) + mem_bytes = torch.cuda.max_memory_allocated() + + # With model offloading + + # Reload but don't move to cuda + pipe = StableDiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + torch_dtype=torch.float16, + ) + + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + outputs_offloaded = pipe(**inputs) + mem_bytes_offloaded = torch.cuda.max_memory_allocated() + + assert np.abs(outputs.images - outputs_offloaded.images).max() < 1e-3 + assert mem_bytes_offloaded < mem_bytes + assert 
mem_bytes_offloaded < 3.5 * 10**9 + for module in pipe.text_encoder, pipe.unet, pipe.vae, pipe.safety_checker: + assert module.device == torch.device("cpu") + + # With attention slicing + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + pipe.enable_attention_slicing() + _ = pipe(**inputs) + mem_bytes_slicing = torch.cuda.max_memory_allocated() + + assert mem_bytes_slicing < mem_bytes_offloaded + assert mem_bytes_slicing < 3 * 10**9 + @nightly @require_torch_gpu diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index b162fe3ac6..e10b608734 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -342,6 +342,47 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase): # make sure that less than 2.2 GB is allocated assert mem_bytes < 2.2 * 10**9 + def test_stable_diffusion_pipeline_with_model_offloading(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + inputs = self.get_inputs(torch_device, dtype=torch.float16) + + # Normal inference + + pipe = StableDiffusionImg2ImgPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + safety_checker=None, + torch_dtype=torch.float16, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe(**inputs) + mem_bytes = torch.cuda.max_memory_allocated() + + # With model offloading + + # Reload but don't move to cuda + pipe = StableDiffusionImg2ImgPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + safety_checker=None, + torch_dtype=torch.float16, + ) + + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + _ = pipe(**inputs) + mem_bytes_offloaded = torch.cuda.max_memory_allocated() + + assert mem_bytes_offloaded < mem_bytes + for module in pipe.text_encoder, pipe.unet, pipe.vae: + assert module.device == torch.device("cpu") + def test_stable_diffusion_img2img_pipeline_multiple_of_8(self): init_image = load_image( "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py index 9c9a0a0186..4657bf2744 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py @@ -393,6 +393,57 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase): # make sure that less than 2.8 GB is allocated assert mem_bytes < 2.8 * 10**9 + def test_stable_diffusion_pipeline_with_model_offloading(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + inputs = self.get_inputs(torch_device, dtype=torch.float16) + + # Normal inference + + pipe = StableDiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-base", + torch_dtype=torch.float16, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + outputs = pipe(**inputs) + mem_bytes = torch.cuda.max_memory_allocated() + + # With model offloading + + # Reload but don't move to cuda + pipe = StableDiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-base", + torch_dtype=torch.float16, + ) + + 
torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + outputs_offloaded = pipe(**inputs) + mem_bytes_offloaded = torch.cuda.max_memory_allocated() + + assert np.abs(outputs.images - outputs_offloaded.images).max() < 1e-3 + assert mem_bytes_offloaded < mem_bytes + assert mem_bytes_offloaded < 3 * 10**9 + for module in pipe.text_encoder, pipe.unet, pipe.vae: + assert module.device == torch.device("cpu") + + # With attention slicing + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + pipe.enable_attention_slicing() + _ = pipe(**inputs) + mem_bytes_slicing = torch.cuda.max_memory_allocated() + assert mem_bytes_slicing < mem_bytes_offloaded + @nightly @require_torch_gpu