# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import json
import os
import re

import pytest
import safetensors.torch
import torch
import torch.nn as nn

from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import check_if_dicts_are_equal

from ...testing_utils import (
    assert_tensors_close,
    backend_empty_cache,
    is_lora,
    is_torch_compile,
    require_peft_backend,
    require_peft_version_greater,
    require_torch_accelerator,
    require_torch_version_greater,
    torch_device,
)


if is_peft_available():
    from diffusers.loaders.peft import PeftAdapterMixin


def check_if_lora_correctly_set(model) -> bool:
    """
    Check if LoRA layers are correctly set in the model.

    Args:
        model: The model to check

    Returns:
        bool: True if LoRA is correctly set, False otherwise
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            return True
    return False

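# A minimal usage sketch for the helper above (hypothetical `model` instance; assumes the PEFT
# backend is installed). The check flips from False to True once an adapter has been injected:
#
#   from peft import LoraConfig
#
#   assert not check_if_lora_correctly_set(model)
#   model.add_adapter(LoraConfig(r=4, lora_alpha=4, target_modules=["to_q"]))
#   assert check_if_lora_correctly_set(model)
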
@is_lora
@require_peft_backend
class LoraTesterMixin:
    """
    Mixin class for testing LoRA/PEFT functionality on models.

    Expected from config mixin:
    - model_class: The model class to test

    Expected methods from config mixin:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: lora
    Use `pytest -m "not lora"` to skip these tests
    """

    def setup_method(self):
        if not issubclass(self.model_class, PeftAdapterMixin):
            pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")

    @torch.no_grad()
    def test_save_load_lora_adapter(self, tmp_path, rank=4, lora_alpha=4, use_dora=False, atol=1e-4, rtol=1e-4):
        from peft import LoraConfig
        from peft.utils import get_peft_model_state_dict

        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict).to(torch_device)

        torch.manual_seed(0)
        output_no_lora = model(**inputs_dict, return_dict=False)[0]

        denoiser_lora_config = LoraConfig(
            r=rank,
            lora_alpha=lora_alpha,
            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
            init_lora_weights=False,
            use_dora=use_dora,
        )
        model.add_adapter(denoiser_lora_config)
        assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"

        torch.manual_seed(0)
        outputs_with_lora = model(**inputs_dict, return_dict=False)[0]

        assert not torch.allclose(output_no_lora, outputs_with_lora, atol=atol, rtol=rtol), (
            "Output should differ with LoRA enabled"
        )

        model.save_lora_adapter(tmp_path)
        assert os.path.isfile(os.path.join(tmp_path, "pytorch_lora_weights.safetensors")), (
            "LoRA weights file not created"
        )

        state_dict_loaded = safetensors.torch.load_file(os.path.join(tmp_path, "pytorch_lora_weights.safetensors"))

        model.unload_lora()
        assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"

        model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
        state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")

        for k in state_dict_loaded:
            loaded_v = state_dict_loaded[k]
            retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
            assert_tensors_close(loaded_v, retrieved_v, atol=atol, rtol=rtol, msg=f"Mismatch in LoRA weight {k}")

        assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload"

        torch.manual_seed(0)
        outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]

        assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=atol, rtol=rtol), (
            "Output should differ with LoRA enabled"
        )
        assert_tensors_close(
            outputs_with_lora,
            outputs_with_lora_2,
            atol=atol,
            rtol=rtol,
            msg="Outputs should match before and after save/load",
        )

    def test_lora_wrong_adapter_name_raises_error(self, tmp_path):
        from peft import LoraConfig

        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict).to(torch_device)

        denoiser_lora_config = LoraConfig(
            r=4,
            lora_alpha=4,
            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
            init_lora_weights=False,
            use_dora=False,
        )
        model.add_adapter(denoiser_lora_config)
        assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"

        wrong_name = "foo"
        with pytest.raises(ValueError) as exc_info:
            model.save_lora_adapter(tmp_path, adapter_name=wrong_name)

        assert f"Adapter name {wrong_name} not found in the model." in str(exc_info.value)

    def test_lora_adapter_metadata_is_loaded_correctly(self, tmp_path, rank=4, lora_alpha=4, use_dora=False):
        from peft import LoraConfig

        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict).to(torch_device)

        denoiser_lora_config = LoraConfig(
            r=rank,
            lora_alpha=lora_alpha,
            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
            init_lora_weights=False,
            use_dora=use_dora,
        )
        model.add_adapter(denoiser_lora_config)
        metadata = model.peft_config["default"].to_dict()
        assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"

        model.save_lora_adapter(tmp_path)
        model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors")
        assert os.path.isfile(model_file), "LoRA weights file not created"

        model.unload_lora()
        assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"

        model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
        parsed_metadata = model.peft_config["default_0"].to_dict()
        check_if_dicts_are_equal(metadata, parsed_metadata)

    def test_lora_adapter_wrong_metadata_raises_error(self, tmp_path):
        from peft import LoraConfig

        from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY

        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict).to(torch_device)

        denoiser_lora_config = LoraConfig(
            r=4,
            lora_alpha=4,
            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
            init_lora_weights=False,
            use_dora=False,
        )
        model.add_adapter(denoiser_lora_config)
        assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"

        model.save_lora_adapter(tmp_path)
        model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors")
        assert os.path.isfile(model_file), "LoRA weights file not created"

        # Perturb the metadata in the state dict
        loaded_state_dict = safetensors.torch.load_file(model_file)
        metadata = {"format": "pt"}
        lora_adapter_metadata = denoiser_lora_config.to_dict()
        lora_adapter_metadata.update({"foo": 1, "bar": 2})
        for key, value in lora_adapter_metadata.items():
            if isinstance(value, set):
                lora_adapter_metadata[key] = list(value)
        metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
        safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)

        model.unload_lora()
        assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"

        with pytest.raises(TypeError) as exc_info:
            model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
        assert "`LoraConfig` class could not be instantiated" in str(exc_info.value)


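# Hypothetical usage sketch for `LoraTesterMixin`: a concrete test class combines it with a
# config mixin that provides `model_class`, `get_init_dict()` and `get_dummy_inputs()`.
# `MyTransformer2DModel` and `MyModelTesterConfigMixin` are placeholder names, not real
# diffusers classes:
#
#   class MyModelTesterConfigMixin:
#       model_class = MyTransformer2DModel
#
#       def get_init_dict(self):
#           return {"in_channels": 4, "num_layers": 1}
#
#       def get_dummy_inputs(self):
#           return {"hidden_states": torch.randn(1, 4, 8, 8, device=torch_device)}
#
#   class TestMyTransformer2DLora(MyModelTesterConfigMixin, LoraTesterMixin):
#       pass
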
@is_lora
@is_torch_compile
@require_peft_backend
@require_peft_version_greater("0.14.0")
@require_torch_version_greater("2.7.1")
@require_torch_accelerator
class LoraHotSwappingForModelTesterMixin:
    """
    Mixin class for testing LoRA hot swapping functionality on models.

    Test that hotswapping does not result in recompilation on the model directly.
    We're not extensively testing the hotswapping functionality since it is implemented in PEFT
    and is extensively tested there. The goal of this test is specifically to ensure that
    hotswapping with diffusers does not require recompilation.

    See https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252
    for the analogous PEFT test.

    Expected from config mixin:
    - model_class: The model class to test

    Optional properties:
    - different_shapes_for_compilation: List of (height, width) tuples for dynamic compilation tests (default: None)

    Expected methods from config mixin:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest marks: lora, torch_compile
    Use `pytest -m "not lora"` or `pytest -m "not torch_compile"` to skip these tests
    """

    @property
    def different_shapes_for_compilation(self) -> list[tuple[int, int]] | None:
        """Optional list of (height, width) tuples for dynamic compilation tests."""
        return None

    def setup_method(self):
        if not issubclass(self.model_class, PeftAdapterMixin):
            pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")

    def teardown_method(self):
        # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
        # there will be recompilation errors, as torch caches the model when run in the same process.
        torch.compiler.reset()
        gc.collect()
        backend_empty_cache(torch_device)

    def _get_lora_config(self, lora_rank, lora_alpha, target_modules):
        from peft import LoraConfig

        lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            target_modules=target_modules,
            init_lora_weights=False,
            use_dora=False,
        )
        return lora_config

    def _get_linear_module_name_other_than_attn(self, model):
        linear_names = [
            name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name
        ]
        return linear_names[0]

    def _check_model_hotswap(
        self, tmp_path, do_compile, rank0, rank1, target_modules0, target_modules1=None, atol=5e-3, rtol=5e-3
    ):
        """
        Check that hotswapping works on a model.

        Steps:
        - create 2 LoRA adapters and save them
        - load the first adapter
        - hotswap the second adapter
        - check that the outputs are correct
        - optionally compile the model
        - optionally check if recompilations happen on different shapes

        Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
        fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
        fine.
        """
        different_shapes = self.different_shapes_for_compilation
        # create 2 adapters with different ranks and alphas
        torch.manual_seed(0)
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict).to(torch_device)

        alpha0, alpha1 = rank0, rank1
        max_rank = max([rank0, rank1])
        if target_modules1 is None:
            target_modules1 = target_modules0[:]
        lora_config0 = self._get_lora_config(rank0, alpha0, target_modules0)
        lora_config1 = self._get_lora_config(rank1, alpha1, target_modules1)

        model.add_adapter(lora_config0, adapter_name="adapter0")
        with torch.inference_mode():
            torch.manual_seed(0)
            output0_before = model(**inputs_dict)["sample"]

        model.add_adapter(lora_config1, adapter_name="adapter1")
        model.set_adapter("adapter1")
        with torch.inference_mode():
            torch.manual_seed(0)
            output1_before = model(**inputs_dict)["sample"]

        # sanity checks:
        assert not torch.allclose(output0_before, output1_before, atol=atol, rtol=rtol)
        assert not (output0_before == 0).all()
        assert not (output1_before == 0).all()

        # save the adapter checkpoints
        model.save_lora_adapter(os.path.join(tmp_path, "0"), safe_serialization=True, adapter_name="adapter0")
        model.save_lora_adapter(os.path.join(tmp_path, "1"), safe_serialization=True, adapter_name="adapter1")
        del model

        # load the first adapter
        torch.manual_seed(0)
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict).to(torch_device)

        if do_compile or (rank0 != rank1):
            # no need to prepare if the model is not compiled and the ranks are identical
            model.enable_lora_hotswap(target_rank=max_rank)

        file_name0 = os.path.join(os.path.join(tmp_path, "0"), "pytorch_lora_weights.safetensors")
        file_name1 = os.path.join(os.path.join(tmp_path, "1"), "pytorch_lora_weights.safetensors")
        model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)

        if do_compile:
            model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)

        with torch.inference_mode():
            # additionally check if dynamic compilation works.
            if different_shapes is not None:
                for height, width in different_shapes:
                    new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
                    _ = model(**new_inputs_dict)
            else:
                output0_after = model(**inputs_dict)["sample"]
                assert_tensors_close(
                    output0_before, output0_after, atol=atol, rtol=rtol, msg="Output mismatch after loading adapter0"
                )

        # hotswap the 2nd adapter
        model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)

        # we need to call forward to potentially trigger recompilation
        with torch.inference_mode():
            if different_shapes is not None:
                for height, width in different_shapes:
                    new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
                    _ = model(**new_inputs_dict)
            else:
                output1_after = model(**inputs_dict)["sample"]
                assert_tensors_close(
                    output1_before,
                    output1_after,
                    atol=atol,
                    rtol=rtol,
                    msg="Output mismatch after hotswapping to adapter1",
                )

        # check error when not passing valid adapter name
        name = "does-not-exist"
        msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
        with pytest.raises(ValueError, match=re.escape(msg)):
            model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)

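    # Condensed sketch of the user-facing flow exercised by `_check_model_hotswap` above
    # (same APIs as in the helper; `model`, `max_rank`, `file_name0` and `file_name1` are
    # stand-ins for the objects created there):
    #
    #   model.enable_lora_hotswap(target_rank=max_rank)  # must happen before the first adapter
    #   model.load_lora_adapter(file_name0, adapter_name="adapter0", prefix=None)
    #   model = torch.compile(model, mode="reduce-overhead")
    #   model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
    #   # forward passes after the hotswap should not trigger recompilation
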
    @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
    def test_hotswapping_model(self, tmp_path, rank0, rank1):
        self._check_model_hotswap(
            tmp_path, do_compile=False, rank0=rank0, rank1=rank1, target_modules0=["to_q", "to_k", "to_v", "to_out.0"]
        )

    @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
    def test_hotswapping_compiled_model_linear(self, tmp_path, rank0, rank1):
        # It's important to add this context to raise an error on recompilation
        target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
        with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
            self._check_model_hotswap(
                tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
            )

    @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
    def test_hotswapping_compiled_model_conv2d(self, tmp_path, rank0, rank1):
        if "unet" not in self.model_class.__name__.lower():
            pytest.skip("Test only applies to UNet.")

        # It's important to add this context to raise an error on recompilation
        target_modules = ["conv", "conv1", "conv2"]
        with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
            self._check_model_hotswap(
                tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
            )

    @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
    def test_hotswapping_compiled_model_both_linear_and_conv2d(self, tmp_path, rank0, rank1):
        if "unet" not in self.model_class.__name__.lower():
            pytest.skip("Test only applies to UNet.")

        # It's important to add this context to raise an error on recompilation
        target_modules = ["to_q", "conv"]
        with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
            self._check_model_hotswap(
                tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
            )

    @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
    def test_hotswapping_compiled_model_both_linear_and_other(self, tmp_path, rank0, rank1):
        # In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping
        # with `torch.compile()` for models that have both linear and conv layers. In this test, we check
        # if we can target a linear layer from the transformer blocks and another linear layer from a
        # non-attention block.
        target_modules = ["to_q"]
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)

        target_modules.append(self._get_linear_module_name_other_than_attn(model))
        del model

        # It's important to add this context to raise an error on recompilation
        with torch._dynamo.config.patch(error_on_recompile=True):
            self._check_model_hotswap(
                tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
            )

    def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
        # ensure that enable_lora_hotswap is called before loading the first adapter
        lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict).to(torch_device)
        model.add_adapter(lora_config)

        msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
        with pytest.raises(RuntimeError, match=msg):
            model.enable_lora_hotswap(target_rank=32)

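    # For contrast with the test above, the ordering that does not raise (sketch with a
    # hypothetical `model` and `lora_config`): prepare for hotswapping first, then add or
    # load adapters.
    #
    #   model.enable_lora_hotswap(target_rank=32)
    #   model.add_adapter(lora_config)  # or model.load_lora_adapter(...)
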
    def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog):
        # ensure that enable_lora_hotswap is called before loading the first adapter
        import logging

        lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict).to(torch_device)
        model.add_adapter(lora_config)
        msg = (
            "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
        )
        with caplog.at_level(logging.WARNING):
            model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
        assert any(msg in record.message for record in caplog.records)

    def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
        # check possibility to ignore the error/warning
        import logging

        lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict).to(torch_device)
        model.add_adapter(lora_config)
        with caplog.at_level(logging.WARNING):
            model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
        assert len(caplog.records) == 0

    def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
        # check that wrong argument value raises an error
        lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict).to(torch_device)
        model.add_adapter(lora_config)
        msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
        with pytest.raises(ValueError, match=msg):
            model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")

    def test_hotswap_second_adapter_targets_more_layers_raises(self, tmp_path, caplog):
        # check the error and log
        import logging

        # at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
        target_modules0 = ["to_q"]
        target_modules1 = ["to_q", "to_k"]
        with pytest.raises(RuntimeError):  # peft raises RuntimeError
            with caplog.at_level(logging.ERROR):
                self._check_model_hotswap(
                    tmp_path,
                    do_compile=True,
                    rank0=8,
                    rank1=8,
                    target_modules0=target_modules0,
                    target_modules1=target_modules1,
                )
        assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)

    @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
    @require_torch_version_greater("2.7.1")
    def test_hotswapping_compile_on_different_shapes(self, tmp_path, rank0, rank1):
        different_shapes_for_compilation = self.different_shapes_for_compilation
        if different_shapes_for_compilation is None:
            pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
        # Setting `use_duck_shape=False` instructs the compiler not to reuse the same symbolic
        # variable for input sizes that happen to be equal. For more details,
        # check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
        torch.fx.experimental._config.use_duck_shape = False

        target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
        with torch._dynamo.config.patch(error_on_recompile=True):
            self._check_model_hotswap(
                tmp_path,
                do_compile=True,
                rank0=rank0,
                rank1=rank1,
                target_modules0=target_modules,
            )
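
# Hypothetical usage sketch for `LoraHotSwappingForModelTesterMixin`: a concrete test class
# combines it with a config mixin and may override `different_shapes_for_compilation` (plus a
# matching `prepare_dummy_input`, which `_check_model_hotswap` calls for dynamic shapes) to
# enable the dynamic-shape test. `MyModelTesterConfigMixin` is a placeholder name:
#
#   class TestMyTransformer2DLoraHotSwap(MyModelTesterConfigMixin, LoraHotSwappingForModelTesterMixin):
#       @property
#       def different_shapes_for_compilation(self):
#           return [(4, 4), (4, 8), (8, 8)]
#
#       def prepare_dummy_input(self, height, width):
#           return {"hidden_states": torch.randn(1, 4, height, width, device=torch_device)}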