From ee40088fe5437f8ed65ec96a22250149e4f334cc Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 23 Jun 2025 10:47:36 +0800 Subject: [PATCH] enable deterministic in bnb 4 bit tests (#11738) * enable deterministic in bnb 4 bit tests Signed-off-by: jiqing-feng * fix 8bit test Signed-off-by: jiqing-feng --------- Signed-off-by: jiqing-feng --- tests/quantization/bnb/test_4bit.py | 5 ++++- tests/quantization/bnb/test_mixed_int8.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index e173a4c721..bdb8920a39 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -96,6 +96,10 @@ class Base4bitTests(unittest.TestCase): num_inference_steps = 10 seed = 0 + @classmethod + def setUpClass(cls): + torch.use_deterministic_algorithms(True) + def get_dummy_inputs(self): prompt_embeds = load_pt( "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt", @@ -480,7 +484,6 @@ class SlowBnb4BitTests(Base4bitTests): r""" Test that loading the model and unquantize it produce correct results. """ - torch.use_deterministic_algorithms(True) self.pipeline_4bit.transformer.dequantize() output = self.pipeline_4bit( prompt=self.prompt, diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index bb7b12de60..d048b0b7db 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -97,6 +97,10 @@ class Base8bitTests(unittest.TestCase): num_inference_steps = 10 seed = 0 + @classmethod + def setUpClass(cls): + torch.use_deterministic_algorithms(True) + def get_dummy_inputs(self): prompt_embeds = load_pt( "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt", @@ -485,7 +489,6 @@ class SlowBnb8bitTests(Base8bitTests): r""" Test that loading the model and unquantize it produce correct results. """ - torch.use_deterministic_algorithms(True) self.pipeline_8bit.transformer.dequantize() output = self.pipeline_8bit( prompt=self.prompt,