
[tests] feat: add AoT compilation tests (#12203)

* feat: add a test for aot.

* up
Author: Sayak Paul
Date: 2025-09-03 11:15:27 +05:30
Committed by: GitHub
Parent: 4acbfbf13b
Commit: ffc8c0c1e1
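
The new test exercises PyTorch's ahead-of-time (AOTInductor) path: the model is traced with torch.export, compiled and packaged into a .pt2 binary via torch._inductor.aoti_compile_and_package, and the loaded binary is swapped in as the model's forward. A minimal self-contained sketch of that flow follows; TinyModel and its input are placeholders standing in for the diffusers model classes the mixin actually tests.

import os
import tempfile

import torch
from torch._inductor.package import load_package


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, hidden_states):
        return self.proj(hidden_states)


model = TinyModel().eval()
inputs_dict = {"hidden_states": torch.randn(2, 8)}

# Trace into an ExportedProgram, then AoT-compile and package it as a .pt2 file.
exported = torch.export.export(model, args=(), kwargs=inputs_dict)

with tempfile.TemporaryDirectory() as tmpdir:
    package_path = os.path.join(tmpdir, "TinyModel.pt2")
    torch._inductor.aoti_compile_and_package(exported, package_path=package_path)

    # The loaded package is directly callable, so it can stand in for forward().
    compiled = load_package(package_path, run_single_threaded=True)
    with torch.no_grad():
        _ = compiled(hidden_states=inputs_dict["hidden_states"])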


@@ -2059,6 +2059,7 @@ class TorchCompileTesterMixin:
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict).to(torch_device)
        model.eval()
        model = torch.compile(model, fullgraph=True)

        with (
@@ -2076,6 +2077,7 @@ class TorchCompileTesterMixin:
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict).to(torch_device)
        model.eval()
        model.compile_repeated_blocks(fullgraph=True)

        recompile_limit = 1
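
compile_repeated_blocks is diffusers' targeted alternative to whole-model compilation: only the model's repeated transformer blocks are compiled, so every repeat reuses one compiled graph. A hedged sketch of what the recompile_limit above is for, continuing with the test's model and inputs_dict; the config patch is an assumption about how the limit is enforced, not a line from this diff.

# Every repeated block should hit the same compiled graph, so a strict
# recompile limit of 1 must not be exceeded during a forward pass.
with torch._dynamo.config.patch(recompile_limit=1), torch.no_grad():
    _ = model(**inputs_dict)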
@@ -2098,7 +2100,6 @@ class TorchCompileTesterMixin:
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.eval()

        # TODO: Can test for other group offloading kwargs later if needed.
        group_offload_kwargs = {
@@ -2111,11 +2112,11 @@ class TorchCompileTesterMixin:
        }
        model.enable_group_offload(**group_offload_kwargs)
        model.compile()

        with torch.no_grad():
            _ = model(**inputs_dict)
            _ = model(**inputs_dict)

    @require_torch_version_greater("2.7.1")
    def test_compile_on_different_shapes(self):
        if self.different_shapes_for_compilation is None:
            pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
@@ -2123,6 +2124,7 @@ class TorchCompileTesterMixin:
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict).to(torch_device)
        model.eval()
        model = torch.compile(model, fullgraph=True, dynamic=True)

        for height, width in self.different_shapes_for_compilation:
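
With dynamic=True, Dynamo compiles shape-polymorphic graphs, so the loop over resolutions (its body continues in the next hunk) should never recompile. A sketch of how that property can be asserted; make_inputs is a hypothetical helper standing in for prepare_dummy_input, while error_on_recompile is a real torch._dynamo.config flag.

model = torch.compile(model, fullgraph=True, dynamic=True)

# Turn any recompilation into a hard error: the first compiled graph should
# serve every (height, width) pair below.
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
    for height, width in [(64, 64), (96, 128)]:
        _ = model(**make_inputs(height=height, width=width))  # hypothetical helper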
@@ -2130,6 +2132,26 @@ class TorchCompileTesterMixin:
            inputs_dict = self.prepare_dummy_input(height=height, width=width)
            _ = model(**inputs_dict)

    def test_compile_works_with_aot(self):
        from torch._inductor.package import load_package

        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict).to(torch_device)

        exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)
        with tempfile.TemporaryDirectory() as tmpdir:
            package_path = os.path.join(tmpdir, f"{self.model_class.__name__}.pt2")
            _ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
            assert os.path.exists(package_path)
            loaded_binary = load_package(package_path, run_single_threaded=True)

        model.forward = loaded_binary

        with torch.no_grad():
            _ = model(**inputs_dict)
            _ = model(**inputs_dict)
@slow
@require_torch_2
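
A practical payoff of the .pt2 artifact this test produces: it can be reloaded in a fresh process and called directly, with no torch.export or recompilation step. A short usage sketch; the path and input shape are placeholders.

import torch
from torch._inductor.package import load_package

# Reload a previously saved AOTInductor package and run it as-is.
compiled = load_package("/path/to/TinyModel.pt2", run_single_threaded=True)
with torch.no_grad():
    out = compiled(hidden_states=torch.randn(2, 8))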