
[Low CPU memory] + device map (#772)

* add accelerate to load models with smaller memory footprint

* remove low_cpu_mem_usage as it is redundant

* move accelerate init weights context to modeling utils

* add test to ensure results are the same when loading with accelerate

* add tests to ensure ram usage gets lower when using accelerate

* move accelerate logic to a single snippet under modeling utils and remove it from configuration utils

* format code to pass the quality check

* fix imports with isort

* add accelerate to test extra deps

* only import accelerate if device_map is set to auto

* move accelerate availability check to diffusers import utils

* format code

* add device map to pipeline abstraction

* lint it to pass PR quality check

* fix class check to use accelerate when using diffusers ModelMixin subclasses

* use low_cpu_mem_usage in transformers if device_map is not available

* NoModuleLayer

* comment out tests

* up

* uP

* finish

* Update src/diffusers/pipelines/stable_diffusion/safety_checker.py

* finish

* uP

* make style

Co-authored-by: Pi Esposito <piero.skywalker@gmail.com>
Author: Patrick von Platen
Date: 2022-10-10 18:05:49 +02:00
Committed by: GitHub
Parent: feaa73243d
Commit: fab17528da
3 changed files with 79 additions and 3 deletions
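
In practice, the feature added here is exercised the way the new tests below do: pass device_map="auto" to from_pretrained. A minimal usage sketch (model ID, revision, and version requirements are taken from the tests in this commit; requires accelerate >= 0.14 and transformers >= 4.23):

import torch
from diffusers import StableDiffusionPipeline

# With device_map="auto", accelerate places weights as they are loaded,
# which is meant to avoid materializing the full checkpoint in CPU RAM first.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
    device_map="auto",
)
pipe.to("cuda")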

File: src/diffusers/pipeline_utils.py

@@ -32,7 +32,19 @@ from tqdm.auto import tqdm
 from .configuration_utils import ConfigMixin
 from .dynamic_modules_utils import get_class_from_dynamic_module
 from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
-from .utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
+from .utils import (
+    CONFIG_NAME,
+    DIFFUSERS_CACHE,
+    ONNX_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+    BaseOutput,
+    is_transformers_available,
+    logging,
+)
+
+
+if is_transformers_available():
+    from transformers import PreTrainedModel
 
 
 INDEX_FILE = "diffusion_pytorch_model.bin"
@@ -338,6 +350,7 @@ class DiffusionPipeline(ConfigMixin):
         custom_pipeline = kwargs.pop("custom_pipeline", None)
         provider = kwargs.pop("provider", None)
         sess_options = kwargs.pop("sess_options", None)
+        device_map = kwargs.pop("device_map", None)
 
         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
@@ -463,6 +476,13 @@ class DiffusionPipeline(ConfigMixin):
                 loading_kwargs["provider"] = provider
                 loading_kwargs["sess_options"] = sess_options
 
+            if (
+                issubclass(class_obj, diffusers.ModelMixin)
+                or is_transformers_available()
+                and issubclass(class_obj, PreTrainedModel)
+            ):
+                loading_kwargs["device_map"] = device_map
+
             # check if the module is in a subdirectory
             if os.path.isdir(os.path.join(cached_folder, name)):
                 loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
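
One detail worth noting in the check above: Python's "or" binds looser than "and", so the condition reads as "is a diffusers ModelMixin, or (transformers is available and is a PreTrainedModel)" -- the PreTrainedModel reference is only evaluated when transformers is importable. A standalone sketch of the same guard pattern (the helper name is hypothetical, not part of this commit):

import diffusers
from diffusers.utils import is_transformers_available

if is_transformers_available():
    from transformers import PreTrainedModel


def accepts_device_map(class_obj) -> bool:
    # hypothetical helper mirroring the condition in the diff:
    # `or` binds looser than `and`, so PreTrainedModel is only
    # referenced when transformers is actually importable
    return issubclass(class_obj, diffusers.ModelMixin) or (
        is_transformers_available() and issubclass(class_obj, PreTrainedModel)
    )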

File: src/diffusers/pipelines/stable_diffusion/safety_checker.py

@@ -19,6 +19,8 @@ def cosine_distance(image_embeds, text_embeds):
 class StableDiffusionSafetyChecker(PreTrainedModel):
     config_class = CLIPConfig
 
+    _no_split_modules = ["CLIPEncoderLayer"]
+
     def __init__(self, config: CLIPConfig):
         super().__init__(config)
@@ -28,8 +30,8 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
         self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
         self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
 
-        self.register_buffer("concept_embeds_weights", torch.ones(17))
-        self.register_buffer("special_care_embeds_weights", torch.ones(3))
+        self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
+        self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
 
     @torch.no_grad()
     def forward(self, clip_input, images):
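
Both registration styles above produce frozen tensors that are saved in the state dict; the switch to nn.Parameter(..., requires_grad=False) additionally exposes the weights through parameters(), which accelerate's device-map machinery inspects, and _no_split_modules marks CLIPEncoderLayer as a block that should not be split across devices when a device map is inferred. A small illustrative module (not from this commit) contrasting the two styles:

import torch
import torch.nn as nn


class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # registered as a buffer: saved in state_dict, never trained,
        # but absent from parameters()
        self.register_buffer("as_buffer", torch.ones(3))
        # frozen parameter: also saved in state_dict and never trained,
        # but visible to anything that iterates over parameters()
        self.as_param = nn.Parameter(torch.ones(3), requires_grad=False)


m = Demo()
assert "as_buffer" in m.state_dict() and "as_param" in m.state_dict()
assert all(not p.requires_grad for p in m.parameters())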

File: tests/test_pipelines.py

@@ -17,12 +17,15 @@ import gc
 import os
 import random
 import tempfile
+import tracemalloc
 import unittest
 
 import numpy as np
 import torch
 
+import accelerate
 import PIL
+import transformers
 from diffusers import (
     AutoencoderKL,
     DDIMPipeline,
@@ -50,6 +53,7 @@ from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device
 from diffusers.utils.testing_utils import get_tests_dir
+from packaging import version
 from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -2034,3 +2038,53 @@ class PipelineTesterMixin(unittest.TestCase):
         pipe(prompt=prompt, num_inference_steps=5, guidance_scale=7.5, callback=test_callback_fn, callback_steps=1)
         assert test_callback_fn.has_been_called
         assert number_of_steps == 6
+
+    @slow
+    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+    def test_stable_diffusion_accelerate_load_works(self):
+        if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
+            return
+
+        if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
+            return
+
+        model_id = "CompVis/stable-diffusion-v1-4"
+        _ = StableDiffusionPipeline.from_pretrained(
+            model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
+        ).to(torch_device)
+
+    @slow
+    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
+    def test_stable_diffusion_accelerate_load_reduces_memory_footprint(self):
+        if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
+            return
+
+        if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
+            return
+
+        pipeline_id = "CompVis/stable-diffusion-v1-4"
+
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        tracemalloc.start()
+        pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
+            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
+        )
+        pipeline_normal_load.to(torch_device)
+        _, peak_normal = tracemalloc.get_traced_memory()
+        tracemalloc.stop()
+
+        del pipeline_normal_load
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        tracemalloc.start()
+        _ = StableDiffusionPipeline.from_pretrained(
+            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
+        )
+        _, peak_accelerate = tracemalloc.get_traced_memory()
+        tracemalloc.stop()
+
+        assert peak_accelerate < peak_normal
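
A note on the measurement: tracemalloc only tracks allocations made through Python's own allocator, so the peaks compared above reflect CPU-side memory during loading, not GPU usage (which the test clears separately with torch.cuda.empty_cache()). The measurement pattern in isolation, with a plain Python allocation standing in for checkpoint loading:

import tracemalloc

tracemalloc.start()
# any Python-level allocation between start() and get_traced_memory()
# contributes to the reported peak
blob = bytearray(50 * 1024 * 1024)  # ~50 MB host-side allocation
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
print(f"current={current / 1e6:.0f} MB, peak={peak / 1e6:.0f} MB")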