From 87f83d3dd9247affcc0912175b2eff5f4a56e75a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 1 Jul 2025 09:40:34 +0530 Subject: [PATCH] [tests] add test for hotswapping + compilation on resolution changes (#11825) * add resolution changes tests to hotswapping test suite. * fixes * docs * explain duck shapes * fix --- .../en/tutorials/using_peft_for_inference.md | 2 + tests/models/test_modeling_common.py | 46 ++++++++++++++++--- .../test_models_transformer_flux.py | 4 ++ 3 files changed, 46 insertions(+), 6 deletions(-) diff --git a/docs/source/en/tutorials/using_peft_for_inference.md b/docs/source/en/tutorials/using_peft_for_inference.md index b18977720c..5a382c1c94 100644 --- a/docs/source/en/tutorials/using_peft_for_inference.md +++ b/docs/source/en/tutorials/using_peft_for_inference.md @@ -315,6 +315,8 @@ pipeline.load_lora_weights( > [!TIP] > Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If a model is recompiled despite following all the steps above, please open an [issue](https://github.com/huggingface/diffusers/issues) with a reproducible example. +If you expect to varied resolutions during inference with this feature, then make sure set `dynamic=True` during compilation. Refer to [this document](../optimization/fp16#dynamic-shape-compilation) for more details. + There are still scenarios where recompulation is unavoidable, such as when the hotswapped LoRA targets more layers than the initial adapter. Try to load the LoRA that targets the most layers *first*. For more details about this limitation, refer to the PEFT [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) docs. ## Merge diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index dcc7ae16a4..def81ecd64 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1350,7 +1350,6 @@ class ModelTesterMixin: new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) # Making sure part of the model will actually end up offloaded self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1}) - print(f" new_model.hf_device_map:{new_model.hf_device_map}") self.check_device_map_is_respected(new_model, new_model.hf_device_map) @@ -2019,6 +2018,8 @@ class LoraHotSwappingForModelTesterMixin: """ + different_shapes_for_compilation = None + def tearDown(self): # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model, # there will be recompilation errors, as torch caches the model when run in the same process. @@ -2056,11 +2057,13 @@ class LoraHotSwappingForModelTesterMixin: - hotswap the second adapter - check that the outputs are correct - optionally compile the model + - optionally check if recompilations happen on different shapes Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is fine. """ + different_shapes = self.different_shapes_for_compilation # create 2 adapters with different ranks and alphas torch.manual_seed(0) init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -2110,19 +2113,30 @@ class LoraHotSwappingForModelTesterMixin: model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None) if do_compile: - model = torch.compile(model, mode="reduce-overhead") + model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None) with torch.inference_mode(): - output0_after = model(**inputs_dict)["sample"] - assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol) + # additionally check if dynamic compilation works. + if different_shapes is not None: + for height, width in different_shapes: + new_inputs_dict = self.prepare_dummy_input(height=height, width=width) + _ = model(**new_inputs_dict) + else: + output0_after = model(**inputs_dict)["sample"] + assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol) # hotswap the 2nd adapter model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None) # we need to call forward to potentially trigger recompilation with torch.inference_mode(): - output1_after = model(**inputs_dict)["sample"] - assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol) + if different_shapes is not None: + for height, width in different_shapes: + new_inputs_dict = self.prepare_dummy_input(height=height, width=width) + _ = model(**new_inputs_dict) + else: + output1_after = model(**inputs_dict)["sample"] + assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol) # check error when not passing valid adapter name name = "does-not-exist" @@ -2240,3 +2254,23 @@ class LoraHotSwappingForModelTesterMixin: do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1 ) assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output) + + @parameterized.expand([(11, 11), (7, 13), (13, 7)]) + @require_torch_version_greater("2.7.1") + def test_hotswapping_compile_on_different_shapes(self, rank0, rank1): + different_shapes_for_compilation = self.different_shapes_for_compilation + if different_shapes_for_compilation is None: + pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.") + # Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic + # variable to represent input sizes that are the same. For more details, + # check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790). + torch.fx.experimental._config.use_duck_shape = False + + target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + with torch._dynamo.config.patch(error_on_recompile=True): + self.check_model_hotswap( + do_compile=True, + rank0=rank0, + rank1=rank1, + target_modules0=target_modules, + ) diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 4552b2e1f5..68b5c02bc0 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -186,6 +186,10 @@ class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase): model_class = FluxTransformer2DModel + different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] def prepare_init_args_and_inputs_for_common(self): return FluxTransformerTests().prepare_init_args_and_inputs_for_common() + + def prepare_dummy_input(self, height, width): + return FluxTransformerTests().prepare_dummy_input(height=height, width=width)