
[bitsandbytes] allow direct CUDA placement of pipelines loaded with bnb components (#9840)

* allow device placement when using bnb quantization.

* warning.

* tests

* fixes

* docs.

* require accelerate version.

* remove print.

* revert to()

* tests

* fixes

* fix: missing AutoencoderKL lora adapter (#9807)

* fix: missing AutoencoderKL lora adapter

* fix

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* fixes

* fix condition test

* updates

* updates

* remove is_offloaded.

* fixes

* better

* empty

---------

Co-authored-by: Emmanuel Benazera <emmanuel.benazera@jolibrain.com>
Author: Sayak Paul
Committed by: GitHub
Date: 2024-12-04 22:27:43 +05:30
Commit: e8da75dff5 (parent: 8a450c3da0)
3 changed files with 91 additions and 6 deletions
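In short: a pipeline whose components were quantized with bitsandbytes can now be placed on CUDA directly via `.to("cuda")`, provided the installed `accelerate` is recent enough (see the version check in the diff below). A minimal usage sketch mirroring the new tests in this commit; the checkpoint id is illustrative:

import torch
from diffusers import BitsAndBytesConfig, DiffusionPipeline, SD3Transformer2DModel

# Quantize only the transformer with bitsandbytes (4-bit NF4).
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
transformer_4bit = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # illustrative checkpoint
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.float16,
)

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    transformer=transformer_4bit,
    torch_dtype=torch.float16,
).to("cuda")  # direct CUDA placement now works with bnb-quantized components

image = pipeline("a photo of a table", num_inference_steps=2).images[0]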


@@ -66,7 +66,6 @@ from ..utils.torch_utils import is_compiled_module
if is_torch_npu_available():
    import torch_npu  # noqa: F401

from .pipeline_loading_utils import (
    ALL_IMPORTABLE_CLASSES,
    CONNECTED_PIPES_KEYS,
@@ -388,6 +387,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
        )
        device = device or device_arg
        pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())

        # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
        def module_is_sequentially_offloaded(module):
@@ -410,10 +410,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
        pipeline_is_sequentially_offloaded = any(
            module_is_sequentially_offloaded(module) for _, module in self.components.items()
        )
        if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":
            raise ValueError(
                "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
            )
        if device and torch.device(device).type == "cuda":
            if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
                raise ValueError(
                    "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
                )
            # PR: https://github.com/huggingface/accelerate/pull/3223/
            elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
                raise ValueError(
                    "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
                )

        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
        if is_pipeline_device_mapped:
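The gating added above can be read in isolation: the CUDA move is rejected only when sequential offloading is active on a pipeline without bnb components, or when bnb components are present but the installed `accelerate` predates the fix. A standalone sketch of that branching, with the pipeline state reduced to illustrative booleans (not the actual method):

def _validate_cuda_move(sequentially_offloaded: bool, has_bnb: bool, accelerate_supports_bnb_move: bool) -> None:
    # Illustrative only: mirrors the branching added to DiffusionPipeline.to() for device == "cuda".
    if sequentially_offloaded and not has_bnb:
        raise ValueError("Sequential CPU offload is active; call .to('cpu') or drop the move.")
    elif has_bnb and not accelerate_supports_bnb_move:
        raise ValueError("Moving a bitsandbytes-quantized pipeline to CUDA needs a newer `accelerate`.")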


@@ -18,10 +18,11 @@ import tempfile
import unittest
import numpy as np
import pytest
import safetensors.torch
from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
from diffusers.utils import logging
from diffusers.utils import is_accelerate_version, logging
from diffusers.utils.testing_utils import (
    CaptureLogger,
    is_bitsandbytes_available,
@@ -47,6 +48,7 @@ def get_some_linear_layer(model):
if is_transformers_available():
    from transformers import BitsAndBytesConfig as BnbConfig
    from transformers import T5EncoderModel
if is_torch_available():
@@ -483,6 +485,47 @@ class SlowBnb4BitTests(Base4bitTests):
assert "Pipelines loaded with `dtype=torch.float16`" in cap_logger.out
@pytest.mark.xfail(
condition=is_accelerate_version("<=", "1.1.1"),
reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
strict=True,
)
def test_pipeline_cuda_placement_works_with_nf4(self):
transformer_nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
transformer_4bit = SD3Transformer2DModel.from_pretrained(
self.model_name,
subfolder="transformer",
quantization_config=transformer_nf4_config,
torch_dtype=torch.float16,
)
text_encoder_3_nf4_config = BnbConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
text_encoder_3_4bit = T5EncoderModel.from_pretrained(
self.model_name,
subfolder="text_encoder_3",
quantization_config=text_encoder_3_nf4_config,
torch_dtype=torch.float16,
)
# CUDA device placement works.
pipeline_4bit = DiffusionPipeline.from_pretrained(
self.model_name,
transformer=transformer_4bit,
text_encoder_3=text_encoder_3_4bit,
torch_dtype=torch.float16,
).to("cuda")
# Check if inference works.
_ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2)
del pipeline_4bit
@require_transformers_version_greater("4.44.0")
class SlowBnb4BitFluxTests(Base4bitTests):


@@ -17,8 +17,10 @@ import tempfile
import unittest
import numpy as np
import pytest
from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
from diffusers.utils import is_accelerate_version
from diffusers.utils.testing_utils import (
    CaptureLogger,
    is_bitsandbytes_available,
@@ -44,6 +46,7 @@ def get_some_linear_layer(model):
if is_transformers_available():
    from transformers import BitsAndBytesConfig as BnbConfig
    from transformers import T5EncoderModel
if is_torch_available():
@@ -432,6 +435,39 @@ class SlowBnb8bitTests(Base8bitTests):
output_type="np",
).images
@pytest.mark.xfail(
condition=is_accelerate_version("<=", "1.1.1"),
reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
strict=True,
)
def test_pipeline_cuda_placement_works_with_mixed_int8(self):
transformer_8bit_config = BitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name,
subfolder="transformer",
quantization_config=transformer_8bit_config,
torch_dtype=torch.float16,
)
text_encoder_3_8bit_config = BnbConfig(load_in_8bit=True)
text_encoder_3_8bit = T5EncoderModel.from_pretrained(
self.model_name,
subfolder="text_encoder_3",
quantization_config=text_encoder_3_8bit_config,
torch_dtype=torch.float16,
)
# CUDA device placement works.
pipeline_8bit = DiffusionPipeline.from_pretrained(
self.model_name,
transformer=transformer_8bit,
text_encoder_3=text_encoder_3_8bit,
torch_dtype=torch.float16,
).to("cuda")
# Check if inference works.
_ = pipeline_8bit("table", max_sequence_length=20, num_inference_steps=2)
del pipeline_8bit
@require_transformers_version_greater("4.44.0")
class SlowBnb8bitFluxTests(Base8bitTests):