mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
minimal stable diffusion GPU memory usage with accelerate hooks (#850)
* add method to enable cuda with minimal gpu usage to stable diffusion * add test to minimal cuda memory usage * ensure all models but unet are onn torch.float32 * move to cpu_offload along with minor internal changes to make it work * make it test against accelerate master branch * coming back, its official: I don't know how to make it test againt the master branch from accelerate * make it install accelerate from master on tests * go back to accelerate>=0.11 * undo prettier formatting on yml files * undo prettier formatting on yml files againn
This commit is contained in:
2
.github/workflows/pr_tests.yml
vendored
2
.github/workflows/pr_tests.yml
vendored
@@ -34,6 +34,7 @@ jobs:
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
python -m pip install -e .[quality,test]
|
||||
python -m pip install git+https://github.com/huggingface/accelerate
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -80,6 +81,7 @@ jobs:
|
||||
${CONDA_RUN} python -m pip install --upgrade pip
|
||||
${CONDA_RUN} python -m pip install -e .[quality,test]
|
||||
${CONDA_RUN} python -m pip install --pre torch==${MPS_TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/test/cpu
|
||||
${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate
|
||||
|
||||
- name: Environment
|
||||
shell: arch -arch arm64 bash {0}
|
||||
|
||||
4
.github/workflows/push_tests.yml
vendored
4
.github/workflows/push_tests.yml
vendored
@@ -36,6 +36,7 @@ jobs:
|
||||
python -m pip uninstall -y torch torchvision torchtext
|
||||
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu116
|
||||
python -m pip install -e .[quality,test]
|
||||
python -m pip install git+https://github.com/huggingface/accelerate
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -58,8 +59,6 @@ jobs:
|
||||
name: torch_test_reports
|
||||
path: reports
|
||||
|
||||
|
||||
|
||||
run_examples_single_gpu:
|
||||
name: Examples tests
|
||||
runs-on: [ self-hosted, docker-gpu, single-gpu ]
|
||||
@@ -83,6 +82,7 @@ jobs:
|
||||
python -m pip uninstall -y torch torchvision torchtext
|
||||
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu116
|
||||
python -m pip install -e .[quality,test,training]
|
||||
python -m pip install git+https://github.com/huggingface/accelerate
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
|
||||
@@ -223,6 +223,8 @@ class DiffusionPipeline(ConfigMixin):
|
||||
for name in module_names.keys():
|
||||
module = getattr(self, name)
|
||||
if isinstance(module, torch.nn.Module):
|
||||
if module.device == torch.device("meta"):
|
||||
return torch.device("cpu")
|
||||
return module.device
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.utils import is_accelerate_available
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
@@ -118,6 +119,18 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
# set slice_size = `None` to disable `attention slicing`
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
def cuda_with_minimal_gpu_usage(self):
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device("cuda")
|
||||
self.enable_attention_slicing(1)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
|
||||
@@ -535,3 +535,23 @@ class PipelineSlowTests(unittest.TestCase):
|
||||
tracemalloc.stop()
|
||||
|
||||
assert peak_accelerate < peak_normal
|
||||
|
||||
@slow
|
||||
@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
|
||||
def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
|
||||
pipeline_id = "CompVis/stable-diffusion-v1-4"
|
||||
prompt = "Andromeda galaxy in a bottle"
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
pipeline_id, revision="fp16", torch_dtype=torch.float32, use_auth_token=True
|
||||
)
|
||||
pipeline.cuda_with_minimal_gpu_usage()
|
||||
|
||||
_ = pipeline(prompt)
|
||||
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
# make sure that less than 0.8 GB is allocated
|
||||
assert mem_bytes < 0.8 * 10**9
|
||||
|
||||
Reference in New Issue
Block a user