From fab17528daaf70b3fd440aa6903bfb15ce466f12 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 10 Oct 2022 18:05:49 +0200
Subject: [PATCH] [Low CPU memory] + device map (#772)

* add accelerate to load models with smaller memory footprint
* remove low_cpu_mem_usage as it is redundant
* move accelerate init weights context to modelling utils
* add test to ensure results are the same when loading with accelerate
* add tests to ensure ram usage gets lower when using accelerate
* move accelerate logic to single snippet under modelling utils and remove it from configuration utils
* format code to pass quality check
* fix imports with isort
* add accelerate to test extra deps
* only import accelerate if device_map is set to auto
* move accelerate availability check to diffusers import utils
* format code
* add device map to pipeline abstraction
* lint it to pass PR quality check
* fix class check to use accelerate when using diffusers ModelMixin subclasses
* use low_cpu_mem_usage in transformers if device_map is not available
* NoModuleLayer
* comment out tests
* up
* uP
* finish
* Update src/diffusers/pipelines/stable_diffusion/safety_checker.py
* finish
* uP
* make style

Co-authored-by: Pi Esposito
---
 src/diffusers/pipeline_utils.py          | 22 +++++++-
 .../stable_diffusion/safety_checker.py   |  6 ++-
 tests/test_pipelines.py                  | 54 +++++++++++++++++++
 3 files changed, 79 insertions(+), 3 deletions(-)
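
The bullet points above refer to accelerate's empty-weights initialization, which lives in diffusers' modeling_utils.py and is therefore not part of this three-file diff. As a rough, hypothetical sketch of that idea only (UNet2DConditionModel is used as a stand-in here; the actual loading code differs):

    from accelerate import init_empty_weights

    from diffusers import UNet2DConditionModel

    # Sketch only, not code from this patch: inside init_empty_weights() the module's
    # parameters are created on the "meta" device, so instantiating the model
    # allocates essentially no host RAM.
    with init_empty_weights():
        unet = UNet2DConditionModel()  # default config, purely illustrative

    # The checkpoint weights can then be loaded and placed directly onto the target
    # device(s) instead of first building a full copy of the model on CPU.
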
diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index 35996d6507..81118967aa 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -32,7 +32,19 @@ from tqdm.auto import tqdm
 from .configuration_utils import ConfigMixin
 from .dynamic_modules_utils import get_class_from_dynamic_module
 from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
-from .utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
+from .utils import (
+    CONFIG_NAME,
+    DIFFUSERS_CACHE,
+    ONNX_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+    BaseOutput,
+    is_transformers_available,
+    logging,
+)
+
+
+if is_transformers_available():
+    from transformers import PreTrainedModel
 
 
 INDEX_FILE = "diffusion_pytorch_model.bin"
@@ -338,6 +350,7 @@ class DiffusionPipeline(ConfigMixin):
         custom_pipeline = kwargs.pop("custom_pipeline", None)
         provider = kwargs.pop("provider", None)
         sess_options = kwargs.pop("sess_options", None)
+        device_map = kwargs.pop("device_map", None)
 
         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
@@ -463,6 +476,13 @@ class DiffusionPipeline(ConfigMixin):
                 loading_kwargs["provider"] = provider
                 loading_kwargs["sess_options"] = sess_options
 
+            if (
+                issubclass(class_obj, diffusers.ModelMixin)
+                or is_transformers_available()
+                and issubclass(class_obj, PreTrainedModel)
+            ):
+                loading_kwargs["device_map"] = device_map
+
             # check if the module is in a subdirectory
             if os.path.isdir(os.path.join(cached_folder, name)):
                 loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py
index 773a7d4b21..3984171f57 100644
--- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py
+++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py
@@ -19,6 +19,8 @@ def cosine_distance(image_embeds, text_embeds):
 class StableDiffusionSafetyChecker(PreTrainedModel):
     config_class = CLIPConfig
 
+    _no_split_modules = ["CLIPEncoderLayer"]
+
     def __init__(self, config: CLIPConfig):
         super().__init__(config)
 
@@ -28,8 +30,8 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
         self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
         self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
 
-        self.register_buffer("concept_embeds_weights", torch.ones(17))
-        self.register_buffer("special_care_embeds_weights", torch.ones(3))
+        self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
+        self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
 
     @torch.no_grad()
     def forward(self, clip_input, images):
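
One detail worth noting (my reading, not stated in the patch): the safety checker's two weight tensors change from registered buffers to frozen nn.Parameters. Both forms serialize to the same state_dict keys and are moved by .to(), but only parameters are visible to code that walks named_parameters(), which device-map style placement presumably relies on. A minimal PyTorch sketch of that difference, using a hypothetical Demo module:

    import torch
    from torch import nn


    class Demo(nn.Module):
        def __init__(self):
            super().__init__()
            # Old style used by the safety checker: a registered buffer.
            self.register_buffer("w_buf", torch.ones(3))
            # New style introduced by this patch: a non-trainable parameter.
            self.w_param = nn.Parameter(torch.ones(3), requires_grad=False)


    m = Demo()
    print(sorted(m.state_dict().keys()))         # ['w_buf', 'w_param'] -> both are saved and loaded
    print([n for n, _ in m.named_parameters()])  # ['w_param']          -> only the parameter is listed
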
GPU") + def test_stable_diffusion_accelerate_load_reduces_memory_footprint(self): + if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"): + return + + if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"): + return + + pipeline_id = "CompVis/stable-diffusion-v1-4" + + torch.cuda.empty_cache() + gc.collect() + + tracemalloc.start() + pipeline_normal_load = StableDiffusionPipeline.from_pretrained( + pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True + ) + pipeline_normal_load.to(torch_device) + _, peak_normal = tracemalloc.get_traced_memory() + tracemalloc.stop() + + del pipeline_normal_load + torch.cuda.empty_cache() + gc.collect() + + tracemalloc.start() + _ = StableDiffusionPipeline.from_pretrained( + pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto" + ) + _, peak_accelerate = tracemalloc.get_traced_memory() + + tracemalloc.stop() + + assert peak_accelerate < peak_normal