diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index bd5584296a..acc6d30b79 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -526,7 +526,7 @@ class SlowBnb4BitTests(Base4bitTests):
         reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
         strict=True,
     )
-    def test_pipeline_device_placement_works_with_nf4(self):
+    def test_pipeline_cuda_placement_works_with_nf4(self):
         transformer_nf4_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_quant_type="nf4",
@@ -560,7 +560,7 @@ class SlowBnb4BitTests(Base4bitTests):
         ).to(torch_device)
 
         # Check if inference works.
-        _ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2)
+        _ = pipeline_4bit(self.prompt, max_sequence_length=20, num_inference_steps=2)
 
         del pipeline_4bit
 
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index f987097799..7abb907ff9 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -492,7 +492,7 @@ class SlowBnb8bitTests(Base8bitTests):
         self.assertTrue(max_diff < 1e-2)
 
         # 8bit models cannot be offloaded to CPU.
-        self.assertTrue(self.pipeline_8bit.transformer.device.type == "cuda")
+        self.assertTrue(self.pipeline_8bit.transformer.device.type == torch_device)
         # calling it again shouldn't be a problem
         _ = self.pipeline_8bit(
             prompt=self.prompt,
@@ -534,7 +534,7 @@ class SlowBnb8bitTests(Base8bitTests):
         ).to(device)
 
         # Check if inference works.
-        _ = pipeline_8bit("table", max_sequence_length=20, num_inference_steps=2)
+        _ = pipeline_8bit(self.prompt, max_sequence_length=20, num_inference_steps=2)
 
         del pipeline_8bit
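
Note on the pattern in these hunks: torch_device is the string exported by
diffusers.utils.testing_utils naming the accelerator available on the host
("cuda", "mps", "cpu", ...). Comparing device.type against it, and reusing
the shared self.prompt fixture instead of a hardcoded string, is what keeps
these slow tests portable across backends. A minimal sketch of the idea,
assuming only that torch and diffusers are installed (the tensor below is
illustrative, not part of the test suite):

    import torch
    from diffusers.utils.testing_utils import torch_device

    # torch_device resolves to the best available backend at import time,
    # so device.type comparisons hold on CUDA, MPS, and CPU-only hosts alike.
    tensor = torch.ones(2, 2, device=torch_device)
    assert tensor.device.type == torch_device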