[tests] add test for hotswapping + compilation on resolution changes (#11825)
* add resolution changes tests to hotswapping test suite.
* fixes
* docs
* explain duck shapes
* fix
@@ -315,6 +315,8 @@ pipeline.load_lora_weights(
> [!TIP]
> Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If a model is recompiled despite following all the steps above, please open an [issue](https://github.com/huggingface/diffusers/issues) with a reproducible example.
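
A minimal sketch of what that looks like (assuming `pipeline` is already set up and compiled with the first LoRA loaded as described above; the repo id and prompt are placeholders):

```py
import torch

with torch._dynamo.config.patch(error_on_recompile=True):
    # Hotswap the second LoRA and run a forward pass; any recompilation raises here
    # instead of silently slowing down inference.
    pipeline.load_lora_weights("user/second-lora", hotswap=True, adapter_name="default_0")
    image = pipeline("a prompt").images[0]
```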

If you expect varied resolutions during inference with this feature, make sure to set `dynamic=True` during compilation. Refer to [this document](../optimization/fp16#dynamic-shape-compilation) for more details.
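
For example, a sketch of compiling with dynamic shapes enabled (the `transformer` attribute and compile mode are illustrative; adapt them to your pipeline):

```py
import torch

# With dynamic=True, switching resolutions between calls (say 1024x1024 to 768x1360)
# should not trigger a recompilation.
pipeline.transformer = torch.compile(
    pipeline.transformer, mode="reduce-overhead", fullgraph=True, dynamic=True
)
```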

There are still scenarios where recompilation is unavoidable, such as when the hotswapped LoRA targets more layers than the initial adapter. Try to load the LoRA that targets the most layers *first*. For more details about this limitation, refer to the PEFT [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) docs.
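
In practice (the repo ids below are placeholders), that means loading the broader adapter first and hotswapping the narrower one into its slot:

```py
# `lora-wide` is assumed to target more layers than `lora-narrow`, so it is loaded first;
# the later hotswap then only has to overwrite a subset of the already-prepared layers.
pipeline.load_lora_weights("user/lora-wide", adapter_name="default_0")
# ... compile and run inference ...
pipeline.load_lora_weights("user/lora-narrow", hotswap=True, adapter_name="default_0")
```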

## Merge
@@ -1350,7 +1350,6 @@ class ModelTesterMixin:
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
print(f" new_model.hf_device_map:{new_model.hf_device_map}")

self.check_device_map_is_respected(new_model, new_model.hf_device_map)
@@ -2019,6 +2018,8 @@ class LoraHotSwappingForModelTesterMixin:
"""

different_shapes_for_compilation = None

def tearDown(self):
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
# there will be recompilation errors, as torch caches the model when run in the same process.
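
The method body lies outside this hunk; as a hedged sketch (the mixin name is hypothetical, not the diff's actual code), the cleanup the comment refers to typically boils down to resetting the dynamo cache:

```py
import unittest

import torch


class DynamoCacheCleanupMixin(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        # Clear torch.compile's caches so a test that reuses the same model in the same
        # process does not hit spurious recompilation errors.
        torch._dynamo.reset()
```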
@@ -2056,11 +2057,13 @@ class LoraHotSwappingForModelTesterMixin:
- hotswap the second adapter
- check that the outputs are correct
- optionally compile the model
- optionally check if recompilations happen on different shapes

Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
fine.
"""
different_shapes = self.different_shapes_for_compilation
# create 2 adapters with different ranks and alphas
torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -2110,19 +2113,30 @@ class LoraHotSwappingForModelTesterMixin:
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)

if do_compile:
model = torch.compile(model, mode="reduce-overhead")
model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)

with torch.inference_mode():
output0_after = model(**inputs_dict)["sample"]
assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
# additionally check if dynamic compilation works.
if different_shapes is not None:
for height, width in different_shapes:
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
_ = model(**new_inputs_dict)
else:
output0_after = model(**inputs_dict)["sample"]
assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)

# hotswap the 2nd adapter
model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)

# we need to call forward to potentially trigger recompilation
with torch.inference_mode():
output1_after = model(**inputs_dict)["sample"]
assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
if different_shapes is not None:
for height, width in different_shapes:
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
_ = model(**new_inputs_dict)
else:
output1_after = model(**inputs_dict)["sample"]
assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)

# check error when not passing valid adapter name
name = "does-not-exist"
@@ -2240,3 +2254,23 @@ class LoraHotSwappingForModelTesterMixin:
do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1
)
assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output)

@parameterized.expand([(11, 11), (7, 13), (13, 7)])
@require_torch_version_greater("2.7.1")
def test_hotswapping_compile_on_different_shapes(self, rank0, rank1):
different_shapes_for_compilation = self.different_shapes_for_compilation
if different_shapes_for_compilation is None:
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
# Setting `use_duck_shape=False` instructs the compiler not to reuse the same symbolic
# variable for input sizes that happen to be equal, so a later call where those sizes
# differ does not trigger a recompilation. For more details,
# check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
torch.fx.experimental._config.use_duck_shape = False

target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
with torch._dynamo.config.patch(error_on_recompile=True):
self.check_model_hotswap(
do_compile=True,
rank0=rank0,
rank1=rank1,
target_modules0=target_modules,
)
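
As an aside, a toy sketch of the behavior the `use_duck_shape` comment above describes (standalone example, not part of this diff; the shapes mirror the test's `(4, 4)` and `(4, 8)` inputs):

```py
import torch
import torch.fx.experimental._config as fx_config

# With duck shaping (the default), two dimensions that are equal at trace time share one
# symbolic variable, so a later call where they differ would force a recompile. Disabling
# it gives each dimension its own symbol.
fx_config.use_duck_shape = False

def f(x):
    return x.sin() + x.cos()

compiled = torch.compile(f, dynamic=True)
with torch._dynamo.config.patch(error_on_recompile=True):
    compiled(torch.randn(4, 4))  # first compilation: height == width here
    compiled(torch.randn(4, 8))  # expected not to recompile now that duck shaping is off
```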
@@ -186,6 +186,10 @@ class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):

class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()

def prepare_dummy_input(self, height, width):
return FluxTransformerTests().prepare_dummy_input(height=height, width=width)