mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[chore] misc changes in the bnb tests for consistency. (#11355)
misc changes in the bnb tests for consistency.
This commit is contained in:
@@ -526,7 +526,7 @@ class SlowBnb4BitTests(Base4bitTests):
|
||||
reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
|
||||
strict=True,
|
||||
)
|
||||
def test_pipeline_device_placement_works_with_nf4(self):
|
||||
def test_pipeline_cuda_placement_works_with_nf4(self):
|
||||
transformer_nf4_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
@@ -560,7 +560,7 @@ class SlowBnb4BitTests(Base4bitTests):
|
||||
).to(torch_device)
|
||||
|
||||
# Check if inference works.
|
||||
_ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2)
|
||||
_ = pipeline_4bit(self.prompt, max_sequence_length=20, num_inference_steps=2)
|
||||
|
||||
del pipeline_4bit
|
||||
|
||||
|
||||
@@ -492,7 +492,7 @@ class SlowBnb8bitTests(Base8bitTests):
|
||||
self.assertTrue(max_diff < 1e-2)
|
||||
|
||||
# 8bit models cannot be offloaded to CPU.
|
||||
self.assertTrue(self.pipeline_8bit.transformer.device.type == "cuda")
|
||||
self.assertTrue(self.pipeline_8bit.transformer.device.type == torch_device)
|
||||
# calling it again shouldn't be a problem
|
||||
_ = self.pipeline_8bit(
|
||||
prompt=self.prompt,
|
||||
@@ -534,7 +534,7 @@ class SlowBnb8bitTests(Base8bitTests):
|
||||
).to(device)
|
||||
|
||||
# Check if inference works.
|
||||
_ = pipeline_8bit("table", max_sequence_length=20, num_inference_steps=2)
|
||||
_ = pipeline_8bit(self.prompt, max_sequence_length=20, num_inference_steps=2)
|
||||
|
||||
del pipeline_8bit
|
||||
|
||||
|
||||
Reference in New Issue
Block a user