1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[chore] misc changes in the bnb tests for consistency. (#11355)

misc changes in the bnb tests for consistency.
This commit is contained in:
Sayak Paul
2025-06-02 08:41:10 -07:00
committed by GitHub
parent 3a31b291f1
commit d4dc4d7654
2 changed files with 4 additions and 4 deletions

View File

@@ -526,7 +526,7 @@ class SlowBnb4BitTests(Base4bitTests):
reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
strict=True,
)
def test_pipeline_device_placement_works_with_nf4(self):
def test_pipeline_cuda_placement_works_with_nf4(self):
transformer_nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
@@ -560,7 +560,7 @@ class SlowBnb4BitTests(Base4bitTests):
).to(torch_device)
# Check if inference works.
_ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2)
_ = pipeline_4bit(self.prompt, max_sequence_length=20, num_inference_steps=2)
del pipeline_4bit

View File

@@ -492,7 +492,7 @@ class SlowBnb8bitTests(Base8bitTests):
self.assertTrue(max_diff < 1e-2)
# 8bit models cannot be offloaded to CPU.
self.assertTrue(self.pipeline_8bit.transformer.device.type == "cuda")
self.assertTrue(self.pipeline_8bit.transformer.device.type == torch_device)
# calling it again shouldn't be a problem
_ = self.pipeline_8bit(
prompt=self.prompt,
@@ -534,7 +534,7 @@ class SlowBnb8bitTests(Base8bitTests):
).to(device)
# Check if inference works.
_ = pipeline_8bit("table", max_sequence_length=20, num_inference_steps=2)
_ = pipeline_8bit(self.prompt, max_sequence_length=20, num_inference_steps=2)
del pipeline_8bit