From f1a93c765f4b292352e26fb10670373b8e5837f7 Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Mon, 12 Jan 2026 16:01:58 -0800 Subject: [PATCH] Add Flag to `PeftLoraLoaderMixinTests` to Enable/Disable Text Encoder LoRA Tests (#12962) * Improve incorrect LoRA format error message * Add flag in PeftLoraLoaderMixinTests to disable text encoder LoRA tests * Apply changes to LTX2LoraTests * Further improve incorrect LoRA format error msg following review --------- Co-authored-by: Sayak Paul --- src/diffusers/loaders/lora_pipeline.py | 40 ++++++++++----------- tests/lora/test_lora_layers_auraflow.py | 22 ++---------- tests/lora/test_lora_layers_cogvideox.py | 22 ++---------- tests/lora/test_lora_layers_cogview4.py | 22 ++---------- tests/lora/test_lora_layers_flux2.py | 22 ++---------- tests/lora/test_lora_layers_hunyuanvideo.py | 22 ++---------- tests/lora/test_lora_layers_ltx2.py | 26 ++------------ tests/lora/test_lora_layers_ltx_video.py | 22 ++---------- tests/lora/test_lora_layers_lumina2.py | 22 ++---------- tests/lora/test_lora_layers_mochi.py | 22 ++---------- tests/lora/test_lora_layers_qwenimage.py | 22 ++---------- tests/lora/test_lora_layers_sana.py | 22 ++---------- tests/lora/test_lora_layers_wan.py | 22 ++---------- tests/lora/test_lora_layers_wanvace.py | 22 ++---------- tests/lora/test_lora_layers_z_image.py | 22 ++---------- tests/lora/utils.py | 19 ++++++++++ 16 files changed, 67 insertions(+), 304 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 5fc650a80d..24d1fd7b93 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -214,7 +214,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_unet( state_dict, @@ -641,7 +641,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_unet( state_dict, @@ -1081,7 +1081,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -1377,7 +1377,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -1659,7 +1659,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ) if not (has_lora_keys or has_norm_keys): - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") transformer_lora_state_dict = { k: state_dict.get(k) @@ -2506,7 +2506,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -2703,7 +2703,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -2906,7 +2906,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -3115,7 +3115,7 @@ class LTX2LoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") transformer_peft_state_dict = { k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.") @@ -3333,7 +3333,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -3536,7 +3536,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -3740,7 +3740,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -3940,7 +3940,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -4194,7 +4194,7 @@ class WanLoraLoaderMixin(LoraBaseMixin): ) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False) if load_into_transformer_2: @@ -4471,7 +4471,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin): ) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False) if load_into_transformer_2: @@ -4691,7 +4691,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -4894,7 +4894,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -5100,7 +5100,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -5306,7 +5306,7 @@ class ZImageLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -5509,7 +5509,7 @@ class Flux2LoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, diff --git a/tests/lora/test_lora_layers_auraflow.py b/tests/lora/test_lora_layers_auraflow.py index 91f63c4b56..78ef4ce151 100644 --- a/tests/lora/test_lora_layers_auraflow.py +++ b/tests/lora/test_lora_layers_auraflow.py @@ -76,6 +76,8 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0", "linear_1"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 8, 8, 3) @@ -114,23 +116,3 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in AuraFlow.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py index fa57b4c9c2..7bd54b77ca 100644 --- a/tests/lora/test_lora_layers_cogvideox.py +++ b/tests/lora/test_lora_layers_cogvideox.py @@ -87,6 +87,8 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 16, 16, 3) @@ -147,26 +149,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @unittest.skip("Not supported in CogVideoX.") def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): pass diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py index 30eb8fbb63..e8ee6e7a7d 100644 --- a/tests/lora/test_lora_layers_cogview4.py +++ b/tests/lora/test_lora_layers_cogview4.py @@ -85,6 +85,8 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): "text_encoder", ) + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 32, 32, 3) @@ -162,23 +164,3 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in CogView4.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_flux2.py b/tests/lora/test_lora_layers_flux2.py index 4ae189aceb..d970b7d784 100644 --- a/tests/lora/test_lora_layers_flux2.py +++ b/tests/lora/test_lora_layers_flux2.py @@ -66,6 +66,8 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_cls, text_encoder_id = Mistral3ForConditionalGeneration, "hf-internal-testing/tiny-mistral3-diffusers" denoiser_target_modules = ["to_qkv_mlp_proj", "to_k"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 8, 8, 3) @@ -146,23 +148,3 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in Flux2.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py index cfd5d3146a..e59bc5662f 100644 --- a/tests/lora/test_lora_layers_hunyuanvideo.py +++ b/tests/lora/test_lora_layers_hunyuanvideo.py @@ -117,6 +117,8 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): "text_encoder_2", ) + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 32, 32, 3) @@ -172,26 +174,6 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @nightly @require_torch_accelerator diff --git a/tests/lora/test_lora_layers_ltx2.py b/tests/lora/test_lora_layers_ltx2.py index 886ae70b7d..0a4b14454f 100644 --- a/tests/lora/test_lora_layers_ltx2.py +++ b/tests/lora/test_lora_layers_ltx2.py @@ -150,6 +150,8 @@ class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): denoiser_target_modules = ["to_q", "to_k", "to_out.0"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 5, 32, 32, 3) @@ -267,27 +269,3 @@ class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in LTX2.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_with_text_lora_save_load(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_save_pretrained_with_text_lora(self): - pass diff --git a/tests/lora/test_lora_layers_ltx_video.py b/tests/lora/test_lora_layers_ltx_video.py index 6ab51a5e51..095e5b577c 100644 --- a/tests/lora/test_lora_layers_ltx_video.py +++ b/tests/lora/test_lora_layers_ltx_video.py @@ -76,6 +76,8 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 32, 32, 3) @@ -125,23 +127,3 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in LTXVideo.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_lumina2.py b/tests/lora/test_lora_layers_lumina2.py index 0417b05b33..da032229a7 100644 --- a/tests/lora/test_lora_layers_lumina2.py +++ b/tests/lora/test_lora_layers_lumina2.py @@ -74,6 +74,8 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/dummy-gemma" text_encoder_cls, text_encoder_id = GemmaForCausalLM, "hf-internal-testing/dummy-gemma-diffusers" + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 4, 4, 3) @@ -113,26 +115,6 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @skip_mps @pytest.mark.xfail( condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"), diff --git a/tests/lora/test_lora_layers_mochi.py b/tests/lora/test_lora_layers_mochi.py index 7be81273db..ee82541129 100644 --- a/tests/lora/test_lora_layers_mochi.py +++ b/tests/lora/test_lora_layers_mochi.py @@ -67,6 +67,8 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 7, 16, 16, 3) @@ -117,26 +119,6 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @unittest.skip("Not supported in CogVideoX.") def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): pass diff --git a/tests/lora/test_lora_layers_qwenimage.py b/tests/lora/test_lora_layers_qwenimage.py index 51de2f8e20..73fd026a67 100644 --- a/tests/lora/test_lora_layers_qwenimage.py +++ b/tests/lora/test_lora_layers_qwenimage.py @@ -69,6 +69,8 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ) denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 8, 8, 3) @@ -107,23 +109,3 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in Qwen Image.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_sana.py b/tests/lora/test_lora_layers_sana.py index a860b7b44f..97bf5cbba9 100644 --- a/tests/lora/test_lora_layers_sana.py +++ b/tests/lora/test_lora_layers_sana.py @@ -75,6 +75,8 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma" text_encoder_cls, text_encoder_id = Gemma2Model, "hf-internal-testing/dummy-gemma-for-diffusers" + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 32, 32, 3) @@ -117,26 +119,6 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment") def test_layerwise_casting_inference_denoiser(self): return super().test_layerwise_casting_inference_denoiser() diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py index 5734509b41..5ae16ab4b9 100644 --- a/tests/lora/test_lora_layers_wan.py +++ b/tests/lora/test_lora_layers_wan.py @@ -73,6 +73,8 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 32, 32, 3) @@ -121,23 +123,3 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in Wan.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_wanvace.py b/tests/lora/test_lora_layers_wanvace.py index ab1f57bfc9..c8acaea9be 100644 --- a/tests/lora/test_lora_layers_wanvace.py +++ b/tests/lora/test_lora_layers_wanvace.py @@ -85,6 +85,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 16, 16, 3) @@ -139,26 +141,6 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_text_lora_save_load(self): - pass - def test_layerwise_casting_inference_denoiser(self): super().test_layerwise_casting_inference_denoiser() diff --git a/tests/lora/test_lora_layers_z_image.py b/tests/lora/test_lora_layers_z_image.py index 35d1389d96..8432ea56a6 100644 --- a/tests/lora/test_lora_layers_z_image.py +++ b/tests/lora/test_lora_layers_z_image.py @@ -75,6 +75,8 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_cls, text_encoder_id = Qwen3Model, None # Will be created inline denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 32, 32, 3) @@ -263,23 +265,3 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in ZImage.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 5fae6cac0a..efa49b9f48 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -117,6 +117,7 @@ class PeftLoraLoaderMixinTests: tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, "" tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, "" tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, "" + supports_text_encoder_loras = True unet_kwargs = None transformer_cls = None @@ -333,6 +334,9 @@ class PeftLoraLoaderMixinTests: Tests a simple inference with lora attached on the text encoder and makes sure it works as expected """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -457,6 +461,9 @@ class PeftLoraLoaderMixinTests: Tests a simple inference with lora attached on the text encoder + scale argument and makes sure it works as expected """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -494,6 +501,9 @@ class PeftLoraLoaderMixinTests: Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model and makes sure it works as expected """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -555,6 +565,9 @@ class PeftLoraLoaderMixinTests: """ Tests a simple usecase where users could use saving utilities for LoRA. """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -593,6 +606,9 @@ class PeftLoraLoaderMixinTests: with different ranks and some adapters removed and makes sure it works as expected """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, _, _ = self.get_dummy_components() # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324). text_lora_config = LoraConfig( @@ -651,6 +667,9 @@ class PeftLoraLoaderMixinTests: """ Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device)