1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[tests] tighten compilation tests for quantization (#12002)

* tighten compilation tests for quantization

* up

* up
This commit is contained in:
Sayak Paul
2025-08-07 10:13:14 +05:30
committed by GitHub
parent 5780776c8a
commit 061163142d
3 changed files with 13 additions and 2 deletions

View File

@@ -886,6 +886,7 @@ class Bnb4BitCompileTests(QuantCompileTests, unittest.TestCase):
components_to_quantize=["transformer", "text_encoder_2"],
)
@require_bitsandbytes_version_greater("0.46.1")
def test_torch_compile(self):
torch._dynamo.config.capture_dynamic_output_shape_ops = True
super().test_torch_compile()

View File

@@ -847,6 +847,10 @@ class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
components_to_quantize=["transformer", "text_encoder_2"],
)
@pytest.mark.xfail(
reason="Test fails because of an offloading problem from Accelerate with confusion in hooks."
" Test passes without recompilation context manager. Refer to https://github.com/huggingface/diffusers/pull/12002/files#r2240462757 for details."
)
def test_torch_compile(self):
torch._dynamo.config.capture_dynamic_output_shape_ops = True
super()._test_torch_compile(torch_dtype=torch.float16)

View File

@@ -56,12 +56,18 @@ class QuantCompileTests:
pipe.transformer.compile(fullgraph=True)
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
with torch._dynamo.config.patch(error_on_recompile=True):
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
pipe = self._init_pipeline(self.quantization_config, torch_dtype)
pipe.enable_model_cpu_offload()
pipe.transformer.compile()
# regional compilation is better for offloading.
# see: https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/
if getattr(pipe.transformer, "_repeated_blocks"):
pipe.transformer.compile_repeated_blocks(fullgraph=True)
else:
pipe.transformer.compile()
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)