[CPU offload] correct cpu offload (#1968)
* [CPU offload] correct cpu offload
* finish
* Update docs/source/en/optimization/fp16.mdx
* Update src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Commit 57f7d25934 (parent 50b6513531) in huggingface/diffusers, committed via GitHub.
docs/source/en/optimization/fp16.mdx
@@ -149,7 +149,7 @@ You may see a small performance boost in VAE decode on multi-image batches. Ther
 
 ## Offloading to CPU with accelerate for memory savings
 
-For additional memory savings, you can offload the weights to CPU and load them to GPU when performing the forward pass.
+For additional memory savings, you can offload the weights to CPU and only load them to GPU when performing the forward pass.
 
 To perform CPU offloading, all you have to do is invoke [`~StableDiffusionPipeline.enable_sequential_cpu_offload`]:
 
@@ -162,16 +162,15 @@ pipe = StableDiffusionPipeline.from_pretrained(
     torch_dtype=torch.float16,
 )
-pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
 pipe.enable_sequential_cpu_offload()
 image = pipe(prompt).images[0]
 ```
 
-And you can get the memory consumption to < 2GB.
+And you can get the memory consumption down to < 3GB.
 
-If is also possible to chain it with attention slicing for minimal memory consumption, running it in as little as < 800mb of GPU vRAM:
+It is also possible to chain it with attention slicing for minimal memory consumption (< 2GB).
 
 ```Python
 import torch
@@ -182,7 +181,6 @@ pipe = StableDiffusionPipeline.from_pretrained(
     torch_dtype=torch.float16,
 )
-pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
 pipe.enable_sequential_cpu_offload()
@@ -191,6 +189,8 @@ pipe.enable_attention_slicing(1)
 image = pipe(prompt).images[0]
 ```
 
+**Note**: When using `enable_sequential_cpu_offload()`, it is important to **not** move the pipeline to CUDA beforehand or else the gain in memory consumption will only be minimal. See [this issue](https://github.com/huggingface/diffusers/issues/1934) for more information.
+
 ## Using Channels Last memory format
 
 Channels last memory format is an alternative way of ordering NCHW tensors in memory preserving dimensions ordering. Channels last tensors ordered in such a way that channels become the densest dimension (aka storing images pixel-per-pixel). Since not all operators currently support channels last format it may result in a worst performance, so it's better to try it and see if it works for your model.
@@ -357,4 +357,4 @@ with torch.inference_mode():
 
     # optional: You can disable it via
     # pipe.disable_xformers_memory_efficient_attention()
-```
+```
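Pieced together, the snippet these doc hunks converge on reads as follows after the commit. This is a hedged reconstruction: the hunks truncate the `from_pretrained(` call, so the model id below is an assumption, not something the diff shows.

```Python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # assumed model id; the hunks cut off the from_pretrained( call
    torch_dtype=torch.float16,
)
# Deliberately no `pipe = pipe.to("cuda")` here -- the commit removes that
# line, because moving the pipeline to CUDA first makes the offload gain
# minimal (see the note referencing issue #1934 above).

pipe.enable_sequential_cpu_offload()
pipe.enable_attention_slicing(1)  # optional: chain with attention slicing for < 2GB

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
```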
src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -211,13 +211,10 @@ class AltDiffusionPipeline(DiffusionPipeline):
         device = torch.device(f"cuda:{gpu_id}")
 
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
-            if cpu_offloaded_model is not None:
-                cpu_offload(cpu_offloaded_model, device)
+            cpu_offload(cpu_offloaded_model, device)
 
         if self.safety_checker is not None:
-            # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
-            # fix by only offloading self.safety_checker for now
-            cpu_offload(self.safety_checker.vision_model, device)
+            cpu_offload(self.safety_checker, device, offload_buffers=True)
 
     @property
     def _execution_device(self):
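All eight pipeline hunks in this commit make the same correction: drop the `is not None` guard (presumably because `unet`, `text_encoder`, and `vae` are always present in these pipelines) and offload the whole safety checker with `offload_buffers=True` instead of only its `vision_model`. For readers unfamiliar with what accelerate's `cpu_offload` does, here is a minimal standalone sketch on a hypothetical toy module; the `cpu_offload` call itself is accelerate's real API, everything else is illustrative.

```Python
import torch
from accelerate import cpu_offload  # real accelerate API

class ToyModel(torch.nn.Module):  # hypothetical stand-in for unet/text_encoder/vae
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)
        # A buffer: the kind of tensor that `offload_buffers=True` also offloads.
        self.register_buffer("scale", torch.ones(8))

    def forward(self, x):
        return self.linear(x) * self.scale

model = ToyModel()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Parameters stay on CPU; hooks copy them to `device` just in time for each
# submodule's forward pass and release them afterwards.
cpu_offload(model, execution_device=device, offload_buffers=True)

out = model(torch.randn(2, 8))  # inputs are moved to the execution device by the hooks
print(out.shape)
```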
src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -233,13 +233,10 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         device = torch.device(f"cuda:{gpu_id}")
 
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
-            if cpu_offloaded_model is not None:
-                cpu_offload(cpu_offloaded_model, device)
+            cpu_offload(cpu_offloaded_model, device)
 
         if self.safety_checker is not None:
-            # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
-            # fix by only offloading self.safety_checker for now
-            cpu_offload(self.safety_checker.vision_model, device)
+            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
 
     @property
     def _execution_device(self):
src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
@@ -236,13 +236,10 @@ class CycleDiffusionPipeline(DiffusionPipeline):
         device = torch.device(f"cuda:{gpu_id}")
 
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
-            if cpu_offloaded_model is not None:
-                cpu_offload(cpu_offloaded_model, device)
+            cpu_offload(cpu_offloaded_model, device)
 
         if self.safety_checker is not None:
-            # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
-            # fix by only offloading self.safety_checker for now
-            cpu_offload(self.safety_checker.vision_model, device)
+            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
 
     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -208,13 +208,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
         device = torch.device(f"cuda:{gpu_id}")
 
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
-            if cpu_offloaded_model is not None:
-                cpu_offload(cpu_offloaded_model, device)
+            cpu_offload(cpu_offloaded_model, device)
 
         if self.safety_checker is not None:
-            # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
-            # fix by only offloading self.safety_checker for now
-            cpu_offload(self.safety_checker.vision_model, device)
+            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
 
     @property
     def _execution_device(self):
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -238,13 +238,10 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         device = torch.device(f"cuda:{gpu_id}")
 
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
-            if cpu_offloaded_model is not None:
-                cpu_offload(cpu_offloaded_model, device)
+            cpu_offload(cpu_offloaded_model, device)
 
         if self.safety_checker is not None:
-            # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
-            # fix by only offloading self.safety_checker for now
-            cpu_offload(self.safety_checker.vision_model, device)
+            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
 
     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -272,13 +272,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         device = torch.device(f"cuda:{gpu_id}")
 
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
-            if cpu_offloaded_model is not None:
-                cpu_offload(cpu_offloaded_model, device)
+            cpu_offload(cpu_offloaded_model, device)
 
         if self.safety_checker is not None:
-            # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
-            # fix by only offloading self.safety_checker for now
-            cpu_offload(self.safety_checker.vision_model, device)
+            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
 
     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
@@ -205,13 +205,10 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
         device = torch.device(f"cuda:{gpu_id}")
 
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
-            if cpu_offloaded_model is not None:
-                cpu_offload(cpu_offloaded_model, device)
+            cpu_offload(cpu_offloaded_model, device)
 
         if self.safety_checker is not None:
-            # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
-            # fix by only offloading self.safety_checker for now
-            cpu_offload(self.safety_checker.vision_model, device)
+            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
 
     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
@@ -137,13 +137,10 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
         device = torch.device(f"cuda:{gpu_id}")
 
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
-            if cpu_offloaded_model is not None:
-                cpu_offload(cpu_offloaded_model, device)
+            cpu_offload(cpu_offloaded_model, device)
 
         if self.safety_checker is not None:
-            # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
-            # fix by only offloading self.safety_checker for now
-            cpu_offload(self.safety_checker.vision_model, device)
+            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
 
     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
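To sanity-check the updated memory numbers from the doc hunk (< 3GB with plain sequential offload, < 2GB when chained with attention slicing), one can watch peak GPU allocation around a generation. A hedged sketch, reusing the same assumed model id as the earlier snippet; actual numbers vary with GPU, dtype, and library versions.

```Python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # assumed model id, as in the snippet above
    torch_dtype=torch.float16,
)
pipe.enable_sequential_cpu_offload()  # note: no prior .to("cuda"), per the doc note
pipe.enable_attention_slicing(1)

torch.cuda.reset_peak_memory_stats()
image = pipe("a photo of an astronaut riding a horse on mars").images[0]

# The doc hunk claims this stays under ~2GB with slicing enabled.
print(f"peak GPU memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
```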