mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[tests] tests for compilation + quantization (bnb) (#11672)

* start adding compilation tests for quantization.

* fixes

* make common utility.

* modularize.

* add group offloading+compile

* xfail

* update

* Update tests/quantization/test_torch_compile_utils.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* fixes

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Author: Sayak Paul
Date: 2025-06-11 21:14:24 +05:30
Committed by: GitHub
Parent commit: 33e636cea5
Commit: b6f7933044
4 changed files with 153 additions and 0 deletions
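For orientation, the workflow these new tests exercise looks roughly like the sketch below: load a pipeline with pipeline-level bitsandbytes quantization, compile the transformer, and run a small generation. The model id and 4-bit settings mirror the test code in this commit; treat this as a minimal sketch assuming a CUDA GPU with bitsandbytes installed, not as part of the diff itself.

# Minimal sketch of the pattern under test: bnb quantization + torch.compile.
# Settings mirror the tests below; requires a CUDA GPU and bitsandbytes.
import torch

from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
    },
    components_to_quantize=["transformer", "text_encoder_2"],
)

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
).to("cuda")

# bnb layers emit dynamic-shape ops, so allow them before compiling with fullgraph=True.
torch._dynamo.config.capture_dynamic_output_shape_ops = True
pipe.transformer.compile(fullgraph=True)

# Small resolution and few steps keep the smoke run fast.
image = pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256).images[0]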

View File

@@ -291,6 +291,18 @@ def require_torch_version_greater_equal(torch_version):
    return decorator


def require_torch_version_greater(torch_version):
    """Decorator marking a test that requires torch with a specific version greater."""

    def decorator(test_case):
        correct_torch_version = is_torch_available() and is_torch_version(">", torch_version)
        return unittest.skipUnless(
            correct_torch_version, f"test requires torch with the version greater than {torch_version}"
        )(test_case)

    return decorator


def require_torch_gpu(test_case):
    """Decorator marking a test that requires CUDA and PyTorch."""
    return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(

View File

@@ -30,6 +30,7 @@ from diffusers import (
    FluxTransformer2DModel,
    SD3Transformer2DModel,
)
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import is_accelerate_version, logging
from diffusers.utils.testing_utils import (
    CaptureLogger,
@@ -44,11 +45,14 @@ from diffusers.utils.testing_utils import (
    require_peft_backend,
    require_torch,
    require_torch_accelerator,
    require_torch_version_greater,
    require_transformers_version_greater,
    slow,
    torch_device,
)

from ..test_torch_compile_utils import QuantCompileTests


def get_some_linear_layer(model):
    if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:
@@ -855,3 +859,26 @@ class ExtendedSerializationTest(BaseBnb4BitSerializationTests):
    def test_fp4_double_safe(self):
        self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True)


@require_torch_version_greater("2.7.1")
class Bnb4BitCompileTests(QuantCompileTests):
    quantization_config = PipelineQuantizationConfig(
        quant_backend="bitsandbytes_4bit",
        quant_kwargs={
            "load_in_4bit": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_compute_dtype": torch.bfloat16,
        },
        components_to_quantize=["transformer", "text_encoder_2"],
    )

    def test_torch_compile(self):
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        super()._test_torch_compile(quantization_config=self.quantization_config)

    def test_torch_compile_with_cpu_offload(self):
        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)

    def test_torch_compile_with_group_offload(self):
        super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)

View File

@@ -46,11 +46,14 @@ from diffusers.utils.testing_utils import (
    require_peft_version_greater,
    require_torch,
    require_torch_accelerator,
    require_torch_version_greater_equal,
    require_transformers_version_greater,
    slow,
    torch_device,
)

from ..test_torch_compile_utils import QuantCompileTests


def get_some_linear_layer(model):
    if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:
@@ -821,3 +824,27 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
        out_0 = self.model_0(**inputs)[0]
        out_1 = model_1(**inputs)[0]
        self.assertTrue(torch.equal(out_0, out_1))


@require_torch_version_greater_equal("2.6.0")
class Bnb8BitCompileTests(QuantCompileTests):
    quantization_config = PipelineQuantizationConfig(
        quant_backend="bitsandbytes_8bit",
        quant_kwargs={"load_in_8bit": True},
        components_to_quantize=["transformer", "text_encoder_2"],
    )

    def test_torch_compile(self):
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        super()._test_torch_compile(quantization_config=self.quantization_config, torch_dtype=torch.float16)

    def test_torch_compile_with_cpu_offload(self):
        super()._test_torch_compile_with_cpu_offload(
            quantization_config=self.quantization_config, torch_dtype=torch.float16
        )

    @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
    def test_torch_compile_with_group_offload(self):
        super()._test_torch_compile_with_group_offload(
            quantization_config=self.quantization_config, torch_dtype=torch.float16
        )

View File

@@ -0,0 +1,87 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest

import torch

from diffusers import DiffusionPipeline
from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu, slow, torch_device


@require_torch_gpu
@slow
class QuantCompileTests(unittest.TestCase):
    quantization_config = None

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)
        torch.compiler.reset()

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)
        torch.compiler.reset()

    def _init_pipeline(self, quantization_config, torch_dtype):
        pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-3-medium-diffusers",
            quantization_config=quantization_config,
            torch_dtype=torch_dtype,
        )
        return pipe

    def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16):
        pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda")
        # important to ensure fullgraph=True
        pipe.transformer.compile(fullgraph=True)

        for _ in range(2):
            # small resolutions to ensure speedy execution.
            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)

    def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16):
        pipe = self._init_pipeline(quantization_config, torch_dtype)
        pipe.enable_model_cpu_offload()
        pipe.transformer.compile()

        for _ in range(2):
            # small resolutions to ensure speedy execution.
            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)

    def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16):
        torch._dynamo.config.cache_size_limit = 10000

        pipe = self._init_pipeline(quantization_config, torch_dtype)
        group_offload_kwargs = {
            "onload_device": torch.device("cuda"),
            "offload_device": torch.device("cpu"),
            "offload_type": "leaf_level",
            "use_stream": True,
            "non_blocking": True,
        }
        pipe.transformer.enable_group_offload(**group_offload_kwargs)
        pipe.transformer.compile()
        for name, component in pipe.components.items():
            if name != "transformer" and isinstance(component, torch.nn.Module):
                if torch.device(component.device).type == "cpu":
                    component.to("cuda")

        for _ in range(2):
            # small resolutions to ensure speedy execution.
            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)