From b2e2d1411ce394ef15c41aafb34b3c08beedff0f Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Wed, 26 Oct 2022 10:52:57 -0300 Subject: [PATCH] minimal stable diffusion GPU memory usage with accelerate hooks (#850) * add method to enable cuda with minimal gpu usage to stable diffusion * add test to minimal cuda memory usage * ensure all models but unet are onn torch.float32 * move to cpu_offload along with minor internal changes to make it work * make it test against accelerate master branch * coming back, its official: I don't know how to make it test againt the master branch from accelerate * make it install accelerate from master on tests * go back to accelerate>=0.11 * undo prettier formatting on yml files * undo prettier formatting on yml files againn --- .github/workflows/pr_tests.yml | 2 ++ .github/workflows/push_tests.yml | 4 ++-- src/diffusers/pipeline_utils.py | 2 ++ .../pipeline_stable_diffusion.py | 13 ++++++++++++ tests/test_pipelines.py | 20 +++++++++++++++++++ 5 files changed, 39 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index cf21edf991..81c75fecec 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -34,6 +34,7 @@ jobs: python -m pip install --upgrade pip python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu python -m pip install -e .[quality,test] + python -m pip install git+https://github.com/huggingface/accelerate - name: Environment run: | @@ -80,6 +81,7 @@ jobs: ${CONDA_RUN} python -m pip install --upgrade pip ${CONDA_RUN} python -m pip install -e .[quality,test] ${CONDA_RUN} python -m pip install --pre torch==${MPS_TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/test/cpu + ${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate - name: Environment shell: arch -arch arm64 bash {0} diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index 3e4a81c91c..dfd83aa9af 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -36,6 +36,7 @@ jobs: python -m pip uninstall -y torch torchvision torchtext python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 python -m pip install -e .[quality,test] + python -m pip install git+https://github.com/huggingface/accelerate - name: Environment run: | @@ -58,8 +59,6 @@ jobs: name: torch_test_reports path: reports - - run_examples_single_gpu: name: Examples tests runs-on: [ self-hosted, docker-gpu, single-gpu ] @@ -83,6 +82,7 @@ jobs: python -m pip uninstall -y torch torchvision torchtext python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 python -m pip install -e .[quality,test,training] + python -m pip install git+https://github.com/huggingface/accelerate - name: Environment run: | diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index d307bd5a07..c9c58a7488 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -223,6 +223,8 @@ class DiffusionPipeline(ConfigMixin): for name in module_names.keys(): module = getattr(self, name) if isinstance(module, torch.nn.Module): + if module.device == torch.device("meta"): + return torch.device("cpu") return module.device return torch.device("cpu") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 02a6b45fde..cf4c5c5fde 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -3,6 +3,7 @@ from typing import Callable, List, Optional, Union import torch +from diffusers.utils import is_accelerate_available from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict @@ -118,6 +119,18 @@ class StableDiffusionPipeline(DiffusionPipeline): # set slice_size = `None` to disable `attention slicing` self.enable_attention_slicing(None) + def cuda_with_minimal_gpu_usage(self): + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device("cuda") + self.enable_attention_slicing(1) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: + cpu_offload(cpu_offloaded_model, device) + @torch.no_grad() def __call__( self, diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index a4686366c8..6e9388ca3a 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -535,3 +535,23 @@ class PipelineSlowTests(unittest.TestCase): tracemalloc.stop() assert peak_accelerate < peak_normal + + @slow + @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU") + def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + + pipeline_id = "CompVis/stable-diffusion-v1-4" + prompt = "Andromeda galaxy in a bottle" + + pipeline = StableDiffusionPipeline.from_pretrained( + pipeline_id, revision="fp16", torch_dtype=torch.float32, use_auth_token=True + ) + pipeline.cuda_with_minimal_gpu_usage() + + _ = pipeline(prompt) + + mem_bytes = torch.cuda.max_memory_allocated() + # make sure that less than 0.8 GB is allocated + assert mem_bytes < 0.8 * 10**9