Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[tests] tests for compilation + quantization (bnb) (#11672)
* start adding compilation tests for quantization.
* fixes
* make common utility.
* modularize.
* add group offloading+compile
* xfail
* update
* Update tests/quantization/test_torch_compile_utils.py
  Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
* fixes

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
@@ -291,6 +291,18 @@ def require_torch_version_greater_equal(torch_version):
    return decorator


def require_torch_version_greater(torch_version):
    """Decorator marking a test that requires torch with a specific version greater."""

    def decorator(test_case):
        correct_torch_version = is_torch_available() and is_torch_version(">", torch_version)
        return unittest.skipUnless(
            correct_torch_version, f"test requires torch with the version greater than {torch_version}"
        )(test_case)

    return decorator


def require_torch_gpu(test_case):
    """Decorator marking a test that requires CUDA and PyTorch."""
    return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
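For reference, the new decorator is applied the same way as the existing version gates. A minimal sketch (the test class and method names below are made up for illustration):

import unittest

from diffusers.utils.testing_utils import require_torch_version_greater


@require_torch_version_greater("2.7.1")
class ExampleCompileTest(unittest.TestCase):
    # Skipped entirely unless the installed torch is strictly newer than 2.7.1.
    def test_placeholder(self):
        self.assertTrue(True)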
@@ -30,6 +30,7 @@ from diffusers import (
    FluxTransformer2DModel,
    SD3Transformer2DModel,
)
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import is_accelerate_version, logging
from diffusers.utils.testing_utils import (
    CaptureLogger,
@@ -44,11 +45,14 @@ from diffusers.utils.testing_utils import (
    require_peft_backend,
    require_torch,
    require_torch_accelerator,
    require_torch_version_greater,
    require_transformers_version_greater,
    slow,
    torch_device,
)

from ..test_torch_compile_utils import QuantCompileTests


def get_some_linear_layer(model):
    if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:
@@ -855,3 +859,26 @@ class ExtendedSerializationTest(BaseBnb4BitSerializationTests):

    def test_fp4_double_safe(self):
        self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True)


@require_torch_version_greater("2.7.1")
class Bnb4BitCompileTests(QuantCompileTests):
    quantization_config = PipelineQuantizationConfig(
        quant_backend="bitsandbytes_4bit",
        quant_kwargs={
            "load_in_4bit": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_compute_dtype": torch.bfloat16,
        },
        components_to_quantize=["transformer", "text_encoder_2"],
    )

    def test_torch_compile(self):
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        super()._test_torch_compile(quantization_config=self.quantization_config)

    def test_torch_compile_with_cpu_offload(self):
        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)

    def test_torch_compile_with_group_offload(self):
        super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)
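Outside the test harness, what Bnb4BitCompileTests exercises looks roughly like the following sketch; the checkpoint id is the one hard-coded in the shared base class added below, and the prompt/resolution values mirror the test defaults:

import torch

from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# NF4 4-bit quantization for the transformer and the second text encoder.
quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
    },
    components_to_quantize=["transformer", "text_encoder_2"],
)

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
).to("cuda")

# Compile only the denoiser; fullgraph=True surfaces any graph breaks introduced
# by the quantized linear layers.
pipe.transformer.compile(fullgraph=True)
image = pipe("a dog", num_inference_steps=3, height=256, width=256).images[0]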
@@ -46,11 +46,14 @@ from diffusers.utils.testing_utils import (
    require_peft_version_greater,
    require_torch,
    require_torch_accelerator,
    require_torch_version_greater_equal,
    require_transformers_version_greater,
    slow,
    torch_device,
)

from ..test_torch_compile_utils import QuantCompileTests


def get_some_linear_layer(model):
    if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:
@@ -821,3 +824,27 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
        out_0 = self.model_0(**inputs)[0]
        out_1 = model_1(**inputs)[0]
        self.assertTrue(torch.equal(out_0, out_1))


@require_torch_version_greater_equal("2.6.0")
class Bnb8BitCompileTests(QuantCompileTests):
    quantization_config = PipelineQuantizationConfig(
        quant_backend="bitsandbytes_8bit",
        quant_kwargs={"load_in_8bit": True},
        components_to_quantize=["transformer", "text_encoder_2"],
    )

    def test_torch_compile(self):
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        super()._test_torch_compile(quantization_config=self.quantization_config, torch_dtype=torch.float16)

    def test_torch_compile_with_cpu_offload(self):
        super()._test_torch_compile_with_cpu_offload(
            quantization_config=self.quantization_config, torch_dtype=torch.float16
        )

    @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
    def test_torch_compile_with_group_offload(self):
        super()._test_torch_compile_with_group_offload(
            quantization_config=self.quantization_config, torch_dtype=torch.float16
        )
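The CPU-offload path for the 8-bit case corresponds roughly to this sketch (same assumptions as the 4-bit sketch above; note the float16 dtype used by the 8-bit tests):

import torch

from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_8bit",
    quant_kwargs={"load_in_8bit": True},
    components_to_quantize=["transformer", "text_encoder_2"],
)

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)
# Idle components are moved to CPU between calls; the compiled transformer is
# onloaded to the GPU only while it runs.
pipe.enable_model_cpu_offload()
pipe.transformer.compile()
image = pipe("a dog", num_inference_steps=3, height=256, width=256).images[0]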
tests/quantization/test_torch_compile_utils.py (new file, 87 lines)
@@ -0,0 +1,87 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest

import torch

from diffusers import DiffusionPipeline
from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu, slow, torch_device


@require_torch_gpu
@slow
class QuantCompileTests(unittest.TestCase):
    quantization_config = None

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)
        torch.compiler.reset()

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)
        torch.compiler.reset()

    def _init_pipeline(self, quantization_config, torch_dtype):
        pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-3-medium-diffusers",
            quantization_config=quantization_config,
            torch_dtype=torch_dtype,
        )
        return pipe

    def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16):
        pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda")
        # important to ensure fullgraph=True so graph breaks are not silently tolerated.
        pipe.transformer.compile(fullgraph=True)

        for _ in range(2):
            # small resolutions to ensure speedy execution.
            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)

    def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16):
        pipe = self._init_pipeline(quantization_config, torch_dtype)
        pipe.enable_model_cpu_offload()
        pipe.transformer.compile()

        for _ in range(2):
            # small resolutions to ensure speedy execution.
            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)

    def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16):
        torch._dynamo.config.cache_size_limit = 10000

        pipe = self._init_pipeline(quantization_config, torch_dtype)
        group_offload_kwargs = {
            "onload_device": torch.device("cuda"),
            "offload_device": torch.device("cpu"),
            "offload_type": "leaf_level",
            "use_stream": True,
            "non_blocking": True,
        }
        pipe.transformer.enable_group_offload(**group_offload_kwargs)
        pipe.transformer.compile()
        for name, component in pipe.components.items():
            if name != "transformer" and isinstance(component, torch.nn.Module):
                if torch.device(component.device).type == "cpu":
                    component.to("cuda")

        for _ in range(2):
            # small resolutions to ensure speedy execution.
            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
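Because the compile/offload logic lives in QuantCompileTests, adding the same coverage for another quantization backend only needs a small subclass. A hedged sketch follows; the backend name and quant_kwargs are illustrative assumptions, not part of this commit:

from diffusers.quantizers import PipelineQuantizationConfig

from .test_torch_compile_utils import QuantCompileTests


class TorchAoCompileTests(QuantCompileTests):
    # Assumed config: torchao int8 weight-only quantization of the transformer.
    quantization_config = PipelineQuantizationConfig(
        quant_backend="torchao",
        quant_kwargs={"quant_type": "int8_weight_only"},
        components_to_quantize=["transformer"],
    )

    def test_torch_compile(self):
        super()._test_torch_compile(quantization_config=self.quantization_config)

    def test_torch_compile_with_cpu_offload(self):
        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)

    def test_torch_compile_with_group_offload(self):
        super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)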