huggingface/diffusers (mirror of https://github.com/huggingface/diffusers.git)

[tests] Refactor TorchAO serialization fast tests (#10271)

Refactor the TorchAO serialization fast tests so that the quantization configuration is passed explicitly instead of through class attributes: `get_dummy_model` and the slice-checking helpers now take `quant_method`, `quant_method_kwargs`, and (where relevant) `device` as arguments, and the four per-configuration subclasses (`TorchAoSerializationINTA8W8Test`, `TorchAoSerializationINTA16W8Test`, and their CPU variants) collapse into four explicit test methods on `TorchAoSerializationTest`: `test_int_a8w8_cuda`, `test_int_a16w8_cuda`, `test_int_a8w8_cpu`, and `test_int_a16w8_cpu`. A condensed sketch of the pattern follows the metadata below.
Author: Aryan
Date: 2024-12-23 11:04:57 +05:30
Committed by: GitHub
Parent: 6a970a45c5
Commit: 02c777c065
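The shape of the change is easiest to see side by side. What follows is a minimal, self-contained sketch of the before/after pattern, not the real test code: the quantized forward pass is a trivial stand-in, the expected values are placeholders taken from the diff, and the `skipTest` guard is my own device for keeping the base class inert.

import unittest


def fake_quantized_output(quant_method: str, device: str) -> float:
    # Stand-in for "quantize the model with TorchAoConfig and run a forward
    # pass"; the returned values are placeholders, not real model outputs.
    return {"int8_weight_only": 0.3613, "int8_dynamic_activation_int8_weight": 0.3633}[quant_method]


# Before: one subclass per configuration, parametrized via class attributes.
class SerializationTestOld(unittest.TestCase):
    quant_method = None
    expected = None
    device = "cuda"

    def test_expected_value(self):
        if self.quant_method is None:
            self.skipTest("base class, not a real test")
        self.assertAlmostEqual(fake_quantized_output(self.quant_method, self.device), self.expected)


class INTA16W8Old(SerializationTestOld):
    quant_method, expected = "int8_weight_only", 0.3613


# After: a single class; each test method passes its configuration to a shared helper.
class SerializationTestNew(unittest.TestCase):
    def _check_expected_value(self, quant_method, expected, device):
        self.assertAlmostEqual(fake_quantized_output(quant_method, device), expected)

    def test_int_a16w8_cuda(self):
        self._check_expected_value("int8_weight_only", 0.3613, "cuda")


if __name__ == "__main__":
    unittest.main()

The refactored form removes the need for a do-not-run-directly base class and makes every configuration greppable by its test method name.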


@@ -447,21 +447,19 @@ class TorchAoTest(unittest.TestCase):
             self.get_dummy_components(TorchAoConfig("int42"))
 
 
-# This class is not to be run as a test by itself. See the tests that follow this class
+# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
 @require_torch_gpu
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoSerializationTest(unittest.TestCase):
     model_name = "hf-internal-testing/tiny-flux-pipe"
-    quant_method, quant_method_kwargs = None, None
-    device = "cuda"
 
     def tearDown(self):
         gc.collect()
         torch.cuda.empty_cache()
 
-    def get_dummy_model(self, device=None):
-        quantization_config = TorchAoConfig(self.quant_method, **self.quant_method_kwargs)
+    def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
+        quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
         quantized_model = FluxTransformer2DModel.from_pretrained(
             self.model_name,
             subfolder="transformer",
@@ -497,15 +495,15 @@ class TorchAoSerializationTest(unittest.TestCase):
             "timestep": timestep,
         }
 
-    def test_original_model_expected_slice(self):
-        quantized_model = self.get_dummy_model(torch_device)
+    def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, expected_slice):
+        quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, torch_device)
         inputs = self.get_dummy_tensor_inputs(torch_device)
         output = quantized_model(**inputs)[0]
         output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-        self.assertTrue(np.allclose(output_slice, self.expected_slice, atol=1e-3, rtol=1e-3))
+        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
 
-    def check_serialization_expected_slice(self, expected_slice):
-        quantized_model = self.get_dummy_model(self.device)
+    def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device):
+        quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device)
         with tempfile.TemporaryDirectory() as tmp_dir:
             quantized_model.save_pretrained(tmp_dir, safe_serialization=False)
@@ -524,36 +522,33 @@ class TorchAoSerializationTest(unittest.TestCase):
         )
 
         self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
 
-    def test_serialization_expected_slice(self):
-        self.check_serialization_expected_slice(self.serialized_expected_slice)
+    def test_int_a8w8_cuda(self):
+        quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
+        expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
+        device = "cuda"
+        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
 
+    def test_int_a16w8_cuda(self):
+        quant_method, quant_method_kwargs = "int8_weight_only", {}
+        expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
+        device = "cuda"
+        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
 
-class TorchAoSerializationINTA8W8Test(TorchAoSerializationTest):
-    quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
-    expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
-    serialized_expected_slice = expected_slice
-    device = "cuda"
-
+    def test_int_a8w8_cpu(self):
+        quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
+        expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
+        device = "cpu"
+        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
 
-class TorchAoSerializationINTA16W8Test(TorchAoSerializationTest):
-    quant_method, quant_method_kwargs = "int8_weight_only", {}
-    expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
-    serialized_expected_slice = expected_slice
-    device = "cuda"
-
-class TorchAoSerializationINTA8W8CPUTest(TorchAoSerializationTest):
-    quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
-    expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
-    serialized_expected_slice = expected_slice
-    device = "cpu"
-
-class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest):
-    quant_method, quant_method_kwargs = "int8_weight_only", {}
-    expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
-    serialized_expected_slice = expected_slice
-    device = "cpu"
+    def test_int_a16w8_cpu(self):
+        quant_method, quant_method_kwargs = "int8_weight_only", {}
+        expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
+        device = "cpu"
+        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
 
 
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
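For reference, the serialization round trip that `_check_serialization_expected_slice` exercises follows the general shape below. This is an illustrative helper, not the test's exact code: it assumes a diffusers-style model exposing `save_pretrained`/`from_pretrained`, and because the test saves with `safe_serialization=False` (a PyTorch .bin checkpoint rather than safetensors), the reload is sketched with `use_safetensors=False`; the diff is truncated before the actual reload call.

import tempfile

import numpy as np


def check_serialization_round_trip(model, inputs, expected_slice, atol=1e-3, rtol=1e-3):
    # Save the quantized model, reload it from disk, and verify that the
    # reloaded copy still produces the expected output slice.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir, safe_serialization=False)
        reloaded = type(model).from_pretrained(tmp_dir, use_safetensors=False)
        output = reloaded(**inputs)[0]
        output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
        assert np.allclose(output_slice, expected_slice, atol=atol, rtol=rtol)

Running the round trip through a shared helper like this is what lets each `test_int_*` method cover both the fresh model and its reloaded copy with a single expected slice.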