mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[tests] feat: add AoT compilation tests (#12203)
* feat: add a test for aot. * up
This commit is contained in:
@@ -2059,6 +2059,7 @@ class TorchCompileTesterMixin:
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model = torch.compile(model, fullgraph=True)
|
||||
|
||||
with (
|
||||
@@ -2076,6 +2077,7 @@ class TorchCompileTesterMixin:
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model.compile_repeated_blocks(fullgraph=True)
|
||||
|
||||
recompile_limit = 1
|
||||
@@ -2098,7 +2100,6 @@ class TorchCompileTesterMixin:
|
||||
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.eval()
|
||||
# TODO: Can test for other group offloading kwargs later if needed.
|
||||
group_offload_kwargs = {
|
||||
@@ -2111,11 +2112,11 @@ class TorchCompileTesterMixin:
|
||||
}
|
||||
model.enable_group_offload(**group_offload_kwargs)
|
||||
model.compile()
|
||||
|
||||
with torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
@require_torch_version_greater("2.7.1")
|
||||
def test_compile_on_different_shapes(self):
|
||||
if self.different_shapes_for_compilation is None:
|
||||
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
|
||||
@@ -2123,6 +2124,7 @@ class TorchCompileTesterMixin:
|
||||
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model = torch.compile(model, fullgraph=True, dynamic=True)
|
||||
|
||||
for height, width in self.different_shapes_for_compilation:
|
||||
@@ -2130,6 +2132,26 @@ class TorchCompileTesterMixin:
|
||||
inputs_dict = self.prepare_dummy_input(height=height, width=width)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
def test_compile_works_with_aot(self):
|
||||
from torch._inductor.package import load_package
|
||||
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
package_path = os.path.join(tmpdir, f"{self.model_class.__name__}.pt2")
|
||||
_ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
|
||||
assert os.path.exists(package_path)
|
||||
loaded_binary = load_package(package_path, run_single_threaded=True)
|
||||
|
||||
model.forward = loaded_binary
|
||||
|
||||
with torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_2
|
||||
|
||||
Reference in New Issue
Block a user