mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
TorchAO compile + offloading tests (#11697)
* update
* update
* update
* update
* update
* use property instead
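For orientation, below is a minimal, self-contained sketch (not part of this commit) of the path the updated tests exercise: pipeline-level quantization via PipelineQuantizationConfig, leaf-level group offloading with an optional CUDA stream, and torch.compile on the transformer. The checkpoint id and prompt arguments are illustrative assumptions; the tests build their pipelines through their own _init_pipeline helper.

import torch

from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# Quantize the same components the bnb tests target.
quantization_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_8bit",
    quant_kwargs={"load_in_8bit": True},
    components_to_quantize=["transformer", "text_encoder_2"],
)
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed model id; not taken from the tests
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

# Mirrors the group_offload_kwargs used in _test_torch_compile_with_group_offload_leaf below.
pipe.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,  # the knob the new *_leaf tests toggle
)
pipe.transformer.compile()

# Small resolution and few inference steps, as in the base test, to keep the run cheap.
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)

Note that not every backend supports this combination yet: the diff below marks the bnb 8-bit group-offload test as xfail and skips the TorchAO offloading variants.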
@@ -866,15 +866,17 @@ class ExtendedSerializationTest(BaseBnb4BitSerializationTests):
 @require_torch_version_greater("2.7.1")
 class Bnb4BitCompileTests(QuantCompileTests):
-    quantization_config = PipelineQuantizationConfig(
-        quant_backend="bitsandbytes_8bit",
-        quant_kwargs={
-            "load_in_4bit": True,
-            "bnb_4bit_quant_type": "nf4",
-            "bnb_4bit_compute_dtype": torch.bfloat16,
-        },
-        components_to_quantize=["transformer", "text_encoder_2"],
-    )
+    @property
+    def quantization_config(self):
+        return PipelineQuantizationConfig(
+            quant_backend="bitsandbytes_8bit",
+            quant_kwargs={
+                "load_in_4bit": True,
+                "bnb_4bit_quant_type": "nf4",
+                "bnb_4bit_compute_dtype": torch.bfloat16,
+            },
+            components_to_quantize=["transformer", "text_encoder_2"],
+        )

     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -883,5 +885,7 @@ class Bnb4BitCompileTests(QuantCompileTests):
     def test_torch_compile_with_cpu_offload(self):
         super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)

-    def test_torch_compile_with_group_offload(self):
-        super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)
+    def test_torch_compile_with_group_offload_leaf(self):
+        super()._test_torch_compile_with_group_offload_leaf(
+            quantization_config=self.quantization_config, use_stream=True
+        )
@@ -831,11 +831,13 @@ class BaseBnb8bitSerializationTests(Base8bitTests):

 @require_torch_version_greater_equal("2.6.0")
 class Bnb8BitCompileTests(QuantCompileTests):
-    quantization_config = PipelineQuantizationConfig(
-        quant_backend="bitsandbytes_8bit",
-        quant_kwargs={"load_in_8bit": True},
-        components_to_quantize=["transformer", "text_encoder_2"],
-    )
+    @property
+    def quantization_config(self):
+        return PipelineQuantizationConfig(
+            quant_backend="bitsandbytes_8bit",
+            quant_kwargs={"load_in_8bit": True},
+            components_to_quantize=["transformer", "text_encoder_2"],
+        )

     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -847,7 +849,7 @@ class Bnb8BitCompileTests(QuantCompileTests):
         )

     @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
-    def test_torch_compile_with_group_offload(self):
-        super()._test_torch_compile_with_group_offload(
-            quantization_config=self.quantization_config, torch_dtype=torch.float16
+    def test_torch_compile_with_group_offload_leaf(self):
+        super()._test_torch_compile_with_group_offload_leaf(
+            quantization_config=self.quantization_config, torch_dtype=torch.float16, use_stream=True
         )
@@ -24,7 +24,11 @@ from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu
 @require_torch_gpu
 @slow
 class QuantCompileTests(unittest.TestCase):
-    quantization_config = None
+    @property
+    def quantization_config(self):
+        raise NotImplementedError(
+            "This property should be implemented in the subclass to return the appropriate quantization config."
+        )

     def setUp(self):
         super().setUp()
@@ -64,7 +68,9 @@ class QuantCompileTests(unittest.TestCase):
         # small resolutions to ensure speedy execution.
         pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)

-    def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16):
+    def _test_torch_compile_with_group_offload_leaf(
+        self, quantization_config, torch_dtype=torch.bfloat16, *, use_stream: bool = False
+    ):
         torch._dynamo.config.cache_size_limit = 10000

         pipe = self._init_pipeline(quantization_config, torch_dtype)
@@ -72,8 +78,7 @@ class QuantCompileTests(unittest.TestCase):
             "onload_device": torch.device("cuda"),
             "offload_device": torch.device("cpu"),
             "offload_type": "leaf_level",
-            "use_stream": True,
-            "non_blocking": True,
+            "use_stream": use_stream,
         }
         pipe.transformer.enable_group_offload(**group_offload_kwargs)
         pipe.transformer.compile()
@@ -19,6 +19,7 @@ import unittest
 from typing import List

 import numpy as np
+from parameterized import parameterized
 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel

 from diffusers import (
@@ -29,6 +30,7 @@ from diffusers import (
     TorchAoConfig,
 )
 from diffusers.models.attention_processor import Attention
+from diffusers.quantizers import PipelineQuantizationConfig
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
     backend_synchronize,
@@ -44,6 +46,8 @@ from diffusers.utils.testing_utils import (
     torch_device,
 )

+from ..test_torch_compile_utils import QuantCompileTests
+

 enable_full_determinism()

@@ -625,6 +629,53 @@ class TorchAoSerializationTest(unittest.TestCase):
         self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)


+@require_torchao_version_greater_or_equal("0.7.0")
+class TorchAoCompileTest(QuantCompileTests):
+    @property
+    def quantization_config(self):
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": TorchAoConfig(quant_type="int8_weight_only"),
+            },
+        )
+
+    def test_torch_compile(self):
+        super()._test_torch_compile(quantization_config=self.quantization_config)
+
+    @unittest.skip(
+        "Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
+        "when compiling."
+    )
+    def test_torch_compile_with_cpu_offload(self):
+        # RuntimeError: _apply(): Couldn't swap Linear.weight
+        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
+
+    @unittest.skip(
+        """
+        For `use_stream=False`:
+        - Changing the device of AQT tensor, with `param.data = param.data.to(device)` as done in group offloading implementation
+        is unsupported in TorchAO. When compiling, FakeTensor device mismatch causes failure.
+        For `use_stream=True`:
+        Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
+        """
+    )
+    @parameterized.expand([False, True])
+    def test_torch_compile_with_group_offload_leaf(self):
+        # For use_stream=False:
+        # If we run group offloading without compilation, we will see:
+        # RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
+        # When running with compilation, the error ends up being different:
+        # Dynamo failed to run FX node with fake tensors: call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(s0, 256), dtype=torch.bfloat16), AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=FakeTensor(..., size=(1536, 256), dtype=torch.int8)... , scale=FakeTensor(..., size=(1536,), dtype=torch.bfloat16)... , zero_point=FakeTensor(..., size=(1536,), dtype=torch.int64)... , _layout=PlainLayout()), block_size=(1, 256), shape=torch.Size([1536, 256]), device=cpu, dtype=torch.bfloat16, requires_grad=False), Parameter(FakeTensor(..., device='cuda:0', size=(1536,), dtype=torch.bfloat16,
+        # requires_grad=True))), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu')
+        # Looks like something that will have to be looked into upstream.
+        # for linear layers, weight.tensor_impl shows cuda... but:
+        # weight.tensor_impl.{data,scale,zero_point}.device will be cpu
+
+        # For use_stream=True:
+        # NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
+        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
+
+
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
 @require_torch_accelerator
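For completeness, here is a minimal sketch (not from the repository) of how another backend's compile tests would hook into the property-based QuantCompileTests base shown above. The class name is hypothetical; the config values simply reuse the bnb 8-bit settings from the diff.

from diffusers.quantizers import PipelineQuantizationConfig

from ..test_torch_compile_utils import QuantCompileTests  # relative import as used in the torchao tests above


class MyBackendCompileTests(QuantCompileTests):  # hypothetical subclass for illustration
    @property
    def quantization_config(self):
        # Subclasses now return a config from a property instead of setting a class attribute.
        return PipelineQuantizationConfig(
            quant_backend="bitsandbytes_8bit",
            quant_kwargs={"load_in_8bit": True},
            components_to_quantize=["transformer", "text_encoder_2"],
        )

    def test_torch_compile(self):
        super()._test_torch_compile(quantization_config=self.quantization_config)

    def test_torch_compile_with_group_offload_leaf(self):
        super()._test_torch_compile_with_group_offload_leaf(
            quantization_config=self.quantization_config, use_stream=True
        )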