[bitsandbytes] allow directly CUDA placements of pipelines loaded with bnb components (#9840)
* allow device placement when using bnb quantization.
* warning.
* tests
* fixes
* docs.
* require accelerate version.
* remove print.
* revert to()
* tests
* fixes
* fix: missing AutoencoderKL lora adapter (#9807)
* fix: missing AutoencoderKL lora adapter
* fix

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* fixes
* fix condition test
* updates
* updates
* remove is_offloaded.
* fixes
* better
* empty

---------

Co-authored-by: Emmanuel Benazera <emmanuel.benazera@jolibrain.com>
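In practice, the usage this change enables mirrors the new tests added below: load one or more bitsandbytes-quantized components, assemble the pipeline, and call `.to("cuda")` on it directly. A minimal sketch, assuming accelerate >= 1.1.0.dev0 (per the version check added in this commit) and an illustrative SD3 checkpoint id (the tests below use their own `self.model_name`):

import torch
from diffusers import BitsAndBytesConfig, DiffusionPipeline, SD3Transformer2DModel

# Illustrative checkpoint id, not taken from this commit.
model_id = "stabilityai/stable-diffusion-3-medium-diffusers"

# Quantize the transformer to 4-bit NF4 with bitsandbytes.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
transformer_4bit = SD3Transformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.float16,
)

# With accelerate >= 1.1.0.dev0, a pipeline holding bnb-quantized modules
# can now be placed on CUDA directly; older accelerate raises a ValueError.
pipeline = DiffusionPipeline.from_pretrained(
    model_id,
    transformer=transformer_4bit,
    torch_dtype=torch.float16,
).to("cuda")

image = pipeline("a photo of a table", max_sequence_length=20, num_inference_steps=2).images[0]

On older accelerate versions the same `.to("cuda")` call raises a ValueError, as implemented in the pipeline diff below.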
@@ -66,7 +66,6 @@ from ..utils.torch_utils import is_compiled_module
if is_torch_npu_available():
    import torch_npu  # noqa: F401


from .pipeline_loading_utils import (
    ALL_IMPORTABLE_CLASSES,
    CONNECTED_PIPES_KEYS,
@@ -388,6 +387,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
            )

        device = device or device_arg
        pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())

        # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
        def module_is_sequentially_offloaded(module):
@@ -410,10 +410,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
        pipeline_is_sequentially_offloaded = any(
            module_is_sequentially_offloaded(module) for _, module in self.components.items()
        )
        if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":
            raise ValueError(
                "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
            )
        if device and torch.device(device).type == "cuda":
            if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
                raise ValueError(
                    "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
                )
            # PR: https://github.com/huggingface/accelerate/pull/3223/
            elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
                raise ValueError(
                    "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
                )

        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
        if is_pipeline_device_mapped:
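The reworked check only fires when the target device is CUDA: sequential CPU offload still blocks the move unless the pipeline contains bnb-quantized modules, and bnb-quantized pipelines additionally require accelerate >= 1.1.0.dev0 (see the accelerate PR linked in the comment). A caller-side sketch of the same version gate, using the `is_accelerate_version` helper the tests below import from `diffusers.utils`; `pipeline` is assumed to already hold bnb-quantized components:

from diffusers.utils import is_accelerate_version

# Mirror of the guard added above: only attempt direct CUDA placement of a
# bnb-quantized pipeline when the installed accelerate is new enough.
if is_accelerate_version(">=", "1.1.0.dev0"):
    pipeline = pipeline.to("cuda")
else:
    # Older accelerate: leave the quantized modules where bitsandbytes placed them.
    print("accelerate is too old for direct CUDA placement of bnb-quantized pipelines")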
@@ -18,10 +18,11 @@ import tempfile
import unittest

import numpy as np
import pytest
import safetensors.torch

from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
from diffusers.utils import logging
from diffusers.utils import is_accelerate_version, logging
from diffusers.utils.testing_utils import (
    CaptureLogger,
    is_bitsandbytes_available,
@@ -47,6 +48,7 @@ def get_some_linear_layer(model):


if is_transformers_available():
    from transformers import BitsAndBytesConfig as BnbConfig
    from transformers import T5EncoderModel

if is_torch_available():
@@ -483,6 +485,47 @@ class SlowBnb4BitTests(Base4bitTests):

        assert "Pipelines loaded with `dtype=torch.float16`" in cap_logger.out

    @pytest.mark.xfail(
        condition=is_accelerate_version("<=", "1.1.1"),
        reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
        strict=True,
    )
    def test_pipeline_cuda_placement_works_with_nf4(self):
        transformer_nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        transformer_4bit = SD3Transformer2DModel.from_pretrained(
            self.model_name,
            subfolder="transformer",
            quantization_config=transformer_nf4_config,
            torch_dtype=torch.float16,
        )
        text_encoder_3_nf4_config = BnbConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        text_encoder_3_4bit = T5EncoderModel.from_pretrained(
            self.model_name,
            subfolder="text_encoder_3",
            quantization_config=text_encoder_3_nf4_config,
            torch_dtype=torch.float16,
        )
        # CUDA device placement works.
        pipeline_4bit = DiffusionPipeline.from_pretrained(
            self.model_name,
            transformer=transformer_4bit,
            text_encoder_3=text_encoder_3_4bit,
            torch_dtype=torch.float16,
        ).to("cuda")

        # Check if inference works.
        _ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2)

        del pipeline_4bit


@require_transformers_version_greater("4.44.0")
class SlowBnb4BitFluxTests(Base4bitTests):
@@ -17,8 +17,10 @@ import tempfile
import unittest

import numpy as np
import pytest

from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
from diffusers.utils import is_accelerate_version
from diffusers.utils.testing_utils import (
    CaptureLogger,
    is_bitsandbytes_available,
@@ -44,6 +46,7 @@ def get_some_linear_layer(model):


if is_transformers_available():
    from transformers import BitsAndBytesConfig as BnbConfig
    from transformers import T5EncoderModel

if is_torch_available():
@@ -432,6 +435,39 @@ class SlowBnb8bitTests(Base8bitTests):
            output_type="np",
        ).images

    @pytest.mark.xfail(
        condition=is_accelerate_version("<=", "1.1.1"),
        reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
        strict=True,
    )
    def test_pipeline_cuda_placement_works_with_mixed_int8(self):
        transformer_8bit_config = BitsAndBytesConfig(load_in_8bit=True)
        transformer_8bit = SD3Transformer2DModel.from_pretrained(
            self.model_name,
            subfolder="transformer",
            quantization_config=transformer_8bit_config,
            torch_dtype=torch.float16,
        )
        text_encoder_3_8bit_config = BnbConfig(load_in_8bit=True)
        text_encoder_3_8bit = T5EncoderModel.from_pretrained(
            self.model_name,
            subfolder="text_encoder_3",
            quantization_config=text_encoder_3_8bit_config,
            torch_dtype=torch.float16,
        )
        # CUDA device placement works.
        pipeline_8bit = DiffusionPipeline.from_pretrained(
            self.model_name,
            transformer=transformer_8bit,
            text_encoder_3=text_encoder_3_8bit,
            torch_dtype=torch.float16,
        ).to("cuda")

        # Check if inference works.
        _ = pipeline_8bit("table", max_sequence_length=20, num_inference_steps=2)

        del pipeline_8bit


@require_transformers_version_greater("4.44.0")
class SlowBnb8bitFluxTests(Base8bitTests):