mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Accelerate model loading] Fix meta device and super low memory usage (#1016)
* [Accelerate model loading] Fix meta device and super low memory usage * better naming
This commit is contained in:
committed by
GitHub
parent
e92a603cab
commit
3be9fa97d6
@@ -119,14 +119,13 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
# set slice_size = `None` to disable `attention slicing`
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
def cuda_with_minimal_gpu_usage(self):
|
||||
def enable_sequential_cpu_offload(self):
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device("cuda")
|
||||
self.enable_attention_slicing(1)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
import gc
|
||||
import random
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -730,3 +731,39 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
assert test_callback_fn.has_been_called
|
||||
assert number_of_steps == 51
|
||||
|
||||
def test_stable_diffusion_accelerate_auto_device(self):
|
||||
pipeline_id = "CompVis/stable-diffusion-v1-4"
|
||||
|
||||
start_time = time.time()
|
||||
pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
|
||||
pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
|
||||
)
|
||||
pipeline_normal_load.to(torch_device)
|
||||
normal_load_time = time.time() - start_time
|
||||
|
||||
start_time = time.time()
|
||||
_ = StableDiffusionPipeline.from_pretrained(
|
||||
pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
|
||||
)
|
||||
meta_device_load_time = time.time() - start_time
|
||||
|
||||
assert 2 * meta_device_load_time < normal_load_time
|
||||
|
||||
@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
|
||||
def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
|
||||
pipeline_id = "CompVis/stable-diffusion-v1-4"
|
||||
prompt = "Andromeda galaxy in a bottle"
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(pipeline_id, revision="fp16", torch_dtype=torch.float16)
|
||||
pipeline.enable_attention_slicing(1)
|
||||
pipeline.enable_sequential_cpu_offload()
|
||||
|
||||
_ = pipeline(prompt, num_inference_steps=5)
|
||||
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
# make sure that less than 1.5 GB is allocated
|
||||
assert mem_bytes < 1.5 * 10**9
|
||||
|
||||
@@ -17,15 +17,12 @@ import gc
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import tracemalloc
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import accelerate
|
||||
import PIL
|
||||
import transformers
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMPipeline,
|
||||
@@ -44,8 +41,7 @@ from diffusers import (
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
||||
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
|
||||
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
|
||||
from packaging import version
|
||||
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir
|
||||
from PIL import Image
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
@@ -487,71 +483,3 @@ class PipelineSlowTests(unittest.TestCase):
|
||||
|
||||
# the values aren't exactly equal, but the images look the same visually
|
||||
assert np.abs(ddpm_images - ddim_images).max() < 1e-1
|
||||
|
||||
@require_torch_gpu
|
||||
def test_stable_diffusion_accelerate_load_works(self):
|
||||
if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
|
||||
return
|
||||
|
||||
if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
|
||||
return
|
||||
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
_ = StableDiffusionPipeline.from_pretrained(
|
||||
model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
|
||||
).to(torch_device)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_stable_diffusion_accelerate_load_reduces_memory_footprint(self):
|
||||
if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
|
||||
return
|
||||
|
||||
if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
|
||||
return
|
||||
|
||||
pipeline_id = "CompVis/stable-diffusion-v1-4"
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
tracemalloc.start()
|
||||
pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
|
||||
pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
|
||||
)
|
||||
pipeline_normal_load.to(torch_device)
|
||||
_, peak_normal = tracemalloc.get_traced_memory()
|
||||
tracemalloc.stop()
|
||||
|
||||
del pipeline_normal_load
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
tracemalloc.start()
|
||||
_ = StableDiffusionPipeline.from_pretrained(
|
||||
pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
|
||||
)
|
||||
_, peak_accelerate = tracemalloc.get_traced_memory()
|
||||
|
||||
tracemalloc.stop()
|
||||
|
||||
assert peak_accelerate < peak_normal
|
||||
|
||||
@slow
|
||||
@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
|
||||
def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
|
||||
pipeline_id = "CompVis/stable-diffusion-v1-4"
|
||||
prompt = "Andromeda galaxy in a bottle"
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
pipeline_id, revision="fp16", torch_dtype=torch.float32, use_auth_token=True
|
||||
)
|
||||
pipeline.cuda_with_minimal_gpu_usage()
|
||||
|
||||
_ = pipeline(prompt)
|
||||
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
# make sure that less than 0.8 GB is allocated
|
||||
assert mem_bytes < 0.8 * 10**9
|
||||
|
||||
Reference in New Issue
Block a user