Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[Tests] Enable more general testing for torch.compile() with LoRA hotswapping (#11322)

* refactor hotswap tester.
* fix seeds.
* add to nightly ci.
* move comment.
* move to nightly.
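The refactor turns the UNet-specific hotswap test class into a reusable mixin, so any model test class can opt into the same hotswap/compile checks. As the diff below shows, the mixin only relies on two hooks supplied by the concrete class: `model_class` and `prepare_init_args_and_inputs_for_common()`. A minimal sketch of the consumer side (the class name `MyModelTests` is hypothetical; the init/input values are copied from the small UNet the old `TestLoraHotSwappingForModel` built for itself, and the sketch assumes it lives next to the other model tests so the relative import resolves):

# Hypothetical consumer of the new mixin; not part of this PR.
import unittest

import torch

from diffusers import UNet2DConditionModel
from diffusers.utils.testing_utils import floats_tensor, torch_device

from ..test_modeling_common import LoraHotSwappingForModelTesterMixin


class MyModelTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
    model_class = UNet2DConditionModel
    main_input_name = "sample"

    def prepare_init_args_and_inputs_for_common(self):
        # Kwargs for a deliberately tiny model, plus matching forward-pass inputs.
        init_dict = {
            "block_out_channels": (4, 8),
            "norm_num_groups": 4,
            "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
            "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
            "cross_attention_dim": 8,
            "attention_head_dim": 2,
            "out_channels": 4,
            "in_channels": 4,
            "layers_per_block": 1,
            "sample_size": 16,
        }
        inputs_dict = {
            "sample": floats_tensor((4, 4, 16, 16)).to(torch_device),
            "timestep": torch.tensor([10]).to(torch_device),
            "encoder_hidden_states": floats_tensor((4, 4, 8)).to(torch_device),
        }
        return init_dict, inputs_dict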
Files changed:

.github/workflows/nightly_tests.yml
@@ -142,6 +142,7 @@ jobs:
           HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
           # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
           CUBLAS_WORKSPACE_CONFIG: :16:8
+          RUN_COMPILE: yes
         run: |
           python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
             -s -v -k "not Flax and not Onnx" \
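The new `RUN_COMPILE: yes` environment variable is what lets the compile-dependent tests run in this nightly job; on the test side, the gate is the `is_torch_compile` marker imported in the hunk below. A small, self-contained sketch of how such an env-var gate typically works (the helper name `require_run_compile` is hypothetical, not the diffusers implementation):

# Hypothetical sketch of an env-var gate like the one RUN_COMPILE drives in CI.
import os
import unittest


def require_run_compile(test_case):
    """Skip the test unless RUN_COMPILE is set to a truthy value (e.g., by nightly CI)."""
    enabled = os.environ.get("RUN_COMPILE", "").lower() in ("yes", "true", "1")
    return unittest.skipUnless(enabled, "set RUN_COMPILE=yes to run torch.compile tests")(test_case)


@require_run_compile
class CompileSmokeTest(unittest.TestCase):
    def test_placeholder(self):
        self.assertTrue(True)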
tests/models/test_modeling_common.py

@@ -62,7 +62,6 @@ from diffusers.utils.testing_utils import (
     backend_max_memory_allocated,
     backend_reset_peak_memory_stats,
     backend_synchronize,
-    floats_tensor,
     get_python_version,
     is_torch_compile,
     numpy_cosine_similarity_distance,
@@ -1754,7 +1753,7 @@ class TorchCompileTesterMixin:
 @require_peft_backend
 @require_peft_version_greater("0.14.0")
 @is_torch_compile
-class TestLoraHotSwappingForModel(unittest.TestCase):
+class LoraHotSwappingForModelTesterMixin:
     """Test that hotswapping does not result in recompilation on the model directly.
 
     We're not extensively testing the hotswapping functionality since it is implemented in PEFT and is extensively
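Dropping the `unittest.TestCase` base is what makes the reuse possible: a plain mixin is not collected as a test on its own, so the hotswap tests only run for the model classes that explicitly inherit it (the Flux and UNet test classes further down). A tiny, standalone illustration of that collection behaviour (class names here are hypothetical):

# A mixin without unittest.TestCase is not collected by itself; only the concrete subclass runs.
import unittest


class ChecksMixin:  # no TestCase base -> not collected on its own
    def test_model_name(self):
        self.assertTrue(self.model_class.__name__)


class DummyModel:
    pass


class DummyModelTests(ChecksMixin, unittest.TestCase):  # collected and run
    model_class = DummyModel


if __name__ == "__main__":
    unittest.main()  # runs DummyModelTests.test_model_name only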
@@ -1775,48 +1774,24 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
         gc.collect()
         backend_empty_cache(torch_device)
 
-    def get_small_unet(self):
-        # from diffusers UNet2DConditionModelTests
-        torch.manual_seed(0)
-        init_dict = {
-            "block_out_channels": (4, 8),
-            "norm_num_groups": 4,
-            "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
-            "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
-            "cross_attention_dim": 8,
-            "attention_head_dim": 2,
-            "out_channels": 4,
-            "in_channels": 4,
-            "layers_per_block": 1,
-            "sample_size": 16,
-        }
-        model = UNet2DConditionModel(**init_dict)
-        return model.to(torch_device)
-
-    def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules):
+    def get_lora_config(self, lora_rank, lora_alpha, target_modules):
         # from diffusers test_models_unet_2d_condition.py
         from peft import LoraConfig
 
-        unet_lora_config = LoraConfig(
+        lora_config = LoraConfig(
             r=lora_rank,
             lora_alpha=lora_alpha,
             target_modules=target_modules,
             init_lora_weights=False,
             use_dora=False,
         )
-        return unet_lora_config
+        return lora_config
 
-    def get_dummy_input(self):
-        # from UNet2DConditionModelTests
-        batch_size = 4
-        num_channels = 4
-        sizes = (16, 16)
-
-        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
-        time_step = torch.tensor([10]).to(torch_device)
-        encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
-
-        return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
+    def get_linear_module_name_other_than_attn(self, model):
+        linear_names = [
+            name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name
+        ]
+        return linear_names[0]
 
     def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
         """
@@ -1834,23 +1809,27 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
         fine.
         """
         # create 2 adapters with different ranks and alphas
-        dummy_input = self.get_dummy_input()
+        torch.manual_seed(0)
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
         alpha0, alpha1 = rank0, rank1
         max_rank = max([rank0, rank1])
         if target_modules1 is None:
             target_modules1 = target_modules0[:]
-        lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules0)
-        lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules1)
+        lora_config0 = self.get_lora_config(rank0, alpha0, target_modules0)
+        lora_config1 = self.get_lora_config(rank1, alpha1, target_modules1)
 
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config0, adapter_name="adapter0")
+        model.add_adapter(lora_config0, adapter_name="adapter0")
         with torch.inference_mode():
-            output0_before = unet(**dummy_input)["sample"]
+            torch.manual_seed(0)
+            output0_before = model(**inputs_dict)["sample"]
 
-        unet.add_adapter(lora_config1, adapter_name="adapter1")
-        unet.set_adapter("adapter1")
+        model.add_adapter(lora_config1, adapter_name="adapter1")
+        model.set_adapter("adapter1")
         with torch.inference_mode():
-            output1_before = unet(**dummy_input)["sample"]
+            torch.manual_seed(0)
+            output1_before = model(**inputs_dict)["sample"]
 
         # sanity checks:
         tol = 5e-3
@@ -1860,40 +1839,43 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
 
         with tempfile.TemporaryDirectory() as tmp_dirname:
             # save the adapter checkpoints
-            unet.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
-            unet.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
-            del unet
+            model.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
+            model.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
+            del model
 
             # load the first adapter
-            unet = self.get_small_unet()
+            torch.manual_seed(0)
+            init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+            model = self.model_class(**init_dict).to(torch_device)
+
             if do_compile or (rank0 != rank1):
                 # no need to prepare if the model is not compiled or if the ranks are identical
-                unet.enable_lora_hotswap(target_rank=max_rank)
+                model.enable_lora_hotswap(target_rank=max_rank)
 
             file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors")
             file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors")
-            unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
+            model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
 
             if do_compile:
-                unet = torch.compile(unet, mode="reduce-overhead")
+                model = torch.compile(model, mode="reduce-overhead")
 
             with torch.inference_mode():
-                output0_after = unet(**dummy_input)["sample"]
+                output0_after = model(**inputs_dict)["sample"]
             assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
 
             # hotswap the 2nd adapter
-            unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
+            model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
 
             # we need to call forward to potentially trigger recompilation
             with torch.inference_mode():
-                output1_after = unet(**dummy_input)["sample"]
+                output1_after = model(**inputs_dict)["sample"]
             assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
 
             # check error when not passing valid adapter name
             name = "does-not-exist"
             msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
             with self.assertRaisesRegex(ValueError, msg):
-                unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
+                model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
 
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_model(self, rank0, rank1):
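The key mechanism behind `check_model_hotswap` is that recompilation is turned into a hard failure: the tests wrap the whole hotswap sequence in `torch._dynamo.config.patch(error_on_recompile=True)`, so if swapping in the second adapter forced a new compile, dynamo would raise instead of silently recompiling. A condensed, self-contained sketch of that pattern, using a bare nn.Linear rather than a diffusers model (so it makes no claims about the LoRA APIs):

# Minimal sketch of the recompilation check: with error_on_recompile=True, any second
# compilation of the same compiled module raises, so "no error" means "no recompile".
import torch
import torch.nn as nn


def run_without_recompilation():
    model = nn.Linear(8, 8)
    compiled = torch.compile(model)
    x = torch.randn(4, 8)

    with torch._dynamo.config.patch(error_on_recompile=True):
        with torch.inference_mode():
            compiled(x)  # first call compiles; the first compile is not a "recompile"

        # Mutating weights in place (analogous to hotswapping LoRA weights) keeps shapes and
        # dtypes fixed, so the compiled graph can be reused on the next call.
        model.weight.data.mul_(0.5)

        with torch.inference_mode():
            compiled(x)  # would raise if a recompile were triggered


if __name__ == "__main__":
    run_without_recompilation()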
@@ -1910,6 +1892,9 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
 
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
+        if "unet" not in self.model_class.__name__.lower():
+            return
+
         # It's important to add this context to raise an error on recompilation
         target_modules = ["conv", "conv1", "conv2"]
         with torch._dynamo.config.patch(error_on_recompile=True):
@@ -1917,52 +1902,77 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
 
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
+        if "unet" not in self.model_class.__name__.lower():
+            return
+
         # It's important to add this context to raise an error on recompilation
         target_modules = ["to_q", "conv"]
         with torch._dynamo.config.patch(error_on_recompile=True):
             self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
 
+    @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
+    def test_hotswapping_compiled_model_both_linear_and_other(self, rank0, rank1):
+        # In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping
+        # with `torch.compile()` for models that have both linear and conv layers. In this test, we check
+        # if we can target a linear layer from the transformer blocks and another linear layer from non-attention
+        # block.
+        target_modules = ["to_q"]
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        target_modules.append(self.get_linear_module_name_other_than_attn(model))
+        del model
+
+        # It's important to add this context to raise an error on recompilation
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
+
     def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
         # ensure that enable_lora_hotswap is called before loading the first adapter
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
+
         msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
         with self.assertRaisesRegex(RuntimeError, msg):
-            unet.enable_lora_hotswap(target_rank=32)
+            model.enable_lora_hotswap(target_rank=32)
 
     def test_enable_lora_hotswap_called_after_adapter_added_warning(self):
         # ensure that enable_lora_hotswap is called before loading the first adapter
         from diffusers.loaders.peft import logger
 
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
         msg = (
             "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
         )
         with self.assertLogs(logger=logger, level="WARNING") as cm:
-            unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
+            model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
         assert any(msg in log for log in cm.output)
 
     def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
         # check possibility to ignore the error/warning
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
         with warnings.catch_warnings(record=True) as w:
             warnings.simplefilter("always")  # Capture all warnings
-            unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
+            model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
             self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")
 
     def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
         # check that wrong argument value raises an error
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
         msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
         with self.assertRaisesRegex(ValueError, msg):
-            unet.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
+            model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
 
     def test_hotswap_second_adapter_targets_more_layers_raises(self):
         # check the error and log
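Taken together, the `enable_lora_hotswap` tests above encode a calling convention: prepare the model for hotswapping before the first adapter is loaded (and before compiling), then pass `hotswap=True` when loading new weights into the same adapter name. A condensed usage sketch assembled only from calls that appear in this diff (the checkpoint paths and the `target_rank` value are placeholders, not real files):

# Usage sketch built from the API calls exercised in these tests; paths are placeholders.
import torch

from diffusers import UNet2DConditionModel

model = UNet2DConditionModel.from_pretrained("path/to/unet")  # placeholder checkpoint

# 1. Must happen before the first adapter is loaded; otherwise a RuntimeError is raised
#    (or a warning / nothing, with check_compiled="warn" / "ignore").
model.enable_lora_hotswap(target_rank=16)

# 2. Load the first adapter, then compile once.
model.load_lora_adapter("adapter0/pytorch_lora_weights.safetensors", adapter_name="adapter0", prefix=None)
model = torch.compile(model, mode="reduce-overhead")

# 3. Swap new weights into the existing adapter slot without triggering recompilation.
model.load_lora_adapter("adapter1/pytorch_lora_weights.safetensors", adapter_name="adapter0", hotswap=True, prefix=None)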
tests/models/transformers/test_models_transformer_flux.py

@@ -22,7 +22,7 @@ from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor
 from diffusers.models.embeddings import ImageProjection
 from diffusers.utils.testing_utils import enable_full_determinism, torch_device
 
-from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
 
 
 enable_full_determinism()
@@ -78,7 +78,9 @@ def create_flux_ip_adapter_state_dict(model):
     return ip_state_dict
 
 
-class FluxTransformerTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
+class FluxTransformerTests(
+    ModelTesterMixin, TorchCompileTesterMixin, LoraHotSwappingForModelTesterMixin, unittest.TestCase
+):
     model_class = FluxTransformer2DModel
     main_input_name = "hidden_states"
     # We override the items here because the transformer under consideration is small.
tests/models/unets/test_models_unet_2d_condition.py

@@ -53,7 +53,7 @@ from diffusers.utils.testing_utils import (
     torch_device,
 )
 
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, UNetTesterMixin
 
 
 if is_peft_available():
@@ -350,7 +350,9 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
     return custom_diffusion_attn_procs
 
 
-class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class UNet2DConditionModelTests(
+    ModelTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase
+):
     model_class = UNet2DConditionModel
     main_input_name = "sample"
     # We override the items here because the unet under consideration is small.