mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[tests] tighten compilation tests for quantization (#12002)
* tighten compilation tests for quantization * up * up
This commit is contained in:
@@ -886,6 +886,7 @@ class Bnb4BitCompileTests(QuantCompileTests, unittest.TestCase):
|
||||
components_to_quantize=["transformer", "text_encoder_2"],
|
||||
)
|
||||
|
||||
@require_bitsandbytes_version_greater("0.46.1")
|
||||
def test_torch_compile(self):
|
||||
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
||||
super().test_torch_compile()
|
||||
|
||||
@@ -847,6 +847,10 @@ class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
|
||||
components_to_quantize=["transformer", "text_encoder_2"],
|
||||
)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Test fails because of an offloading problem from Accelerate with confusion in hooks."
|
||||
" Test passes without recompilation context manager. Refer to https://github.com/huggingface/diffusers/pull/12002/files#r2240462757 for details."
|
||||
)
|
||||
def test_torch_compile(self):
|
||||
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
||||
super()._test_torch_compile(torch_dtype=torch.float16)
|
||||
|
||||
@@ -56,12 +56,18 @@ class QuantCompileTests:
|
||||
pipe.transformer.compile(fullgraph=True)
|
||||
|
||||
# small resolutions to ensure speedy execution.
|
||||
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
|
||||
|
||||
def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
|
||||
pipe = self._init_pipeline(self.quantization_config, torch_dtype)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.transformer.compile()
|
||||
# regional compilation is better for offloading.
|
||||
# see: https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/
|
||||
if getattr(pipe.transformer, "_repeated_blocks"):
|
||||
pipe.transformer.compile_repeated_blocks(fullgraph=True)
|
||||
else:
|
||||
pipe.transformer.compile()
|
||||
|
||||
# small resolutions to ensure speedy execution.
|
||||
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
|
||||
|
||||
Reference in New Issue
Block a user