mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
@@ -447,21 +447,19 @@ class TorchAoTest(unittest.TestCase):
|
||||
self.get_dummy_components(TorchAoConfig("int42"))
|
||||
|
||||
|
||||
# This class is not to be run as a test by itself. See the tests that follow this class
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
@require_torchao_version_greater_or_equal("0.7.0")
|
||||
class TorchAoSerializationTest(unittest.TestCase):
|
||||
model_name = "hf-internal-testing/tiny-flux-pipe"
|
||||
quant_method, quant_method_kwargs = None, None
|
||||
device = "cuda"
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_dummy_model(self, device=None):
|
||||
quantization_config = TorchAoConfig(self.quant_method, **self.quant_method_kwargs)
|
||||
def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
|
||||
quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
self.model_name,
|
||||
subfolder="transformer",
|
||||
@@ -497,15 +495,15 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
"timestep": timestep,
|
||||
}
|
||||
|
||||
def test_original_model_expected_slice(self):
|
||||
quantized_model = self.get_dummy_model(torch_device)
|
||||
def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, expected_slice):
|
||||
quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, torch_device)
|
||||
inputs = self.get_dummy_tensor_inputs(torch_device)
|
||||
output = quantized_model(**inputs)[0]
|
||||
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
|
||||
self.assertTrue(np.allclose(output_slice, self.expected_slice, atol=1e-3, rtol=1e-3))
|
||||
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def check_serialization_expected_slice(self, expected_slice):
|
||||
quantized_model = self.get_dummy_model(self.device)
|
||||
def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device):
|
||||
quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
quantized_model.save_pretrained(tmp_dir, safe_serialization=False)
|
||||
@@ -524,36 +522,33 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
)
|
||||
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_serialization_expected_slice(self):
|
||||
self.check_serialization_expected_slice(self.serialized_expected_slice)
|
||||
def test_int_a8w8_cuda(self):
|
||||
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
|
||||
device = "cuda"
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
|
||||
def test_int_a16w8_cuda(self):
|
||||
quant_method, quant_method_kwargs = "int8_weight_only", {}
|
||||
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
|
||||
device = "cuda"
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
|
||||
class TorchAoSerializationINTA8W8Test(TorchAoSerializationTest):
|
||||
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
|
||||
serialized_expected_slice = expected_slice
|
||||
device = "cuda"
|
||||
def test_int_a8w8_cpu(self):
|
||||
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
|
||||
device = "cpu"
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
|
||||
|
||||
class TorchAoSerializationINTA16W8Test(TorchAoSerializationTest):
|
||||
quant_method, quant_method_kwargs = "int8_weight_only", {}
|
||||
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
|
||||
serialized_expected_slice = expected_slice
|
||||
device = "cuda"
|
||||
|
||||
|
||||
class TorchAoSerializationINTA8W8CPUTest(TorchAoSerializationTest):
|
||||
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
|
||||
serialized_expected_slice = expected_slice
|
||||
device = "cpu"
|
||||
|
||||
|
||||
class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest):
|
||||
quant_method, quant_method_kwargs = "int8_weight_only", {}
|
||||
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
|
||||
serialized_expected_slice = expected_slice
|
||||
device = "cpu"
|
||||
def test_int_a16w8_cpu(self):
|
||||
quant_method, quant_method_kwargs = "int8_weight_only", {}
|
||||
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
|
||||
device = "cpu"
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
|
||||
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
|
||||
Reference in New Issue
Block a user