mirror of https://github.com/huggingface/diffusers.git
sayakpaul
2025-09-11 09:11:09 +05:30
parent 9c24d1fc0f
commit cca03df7fc
13 changed files with 60 additions and 50 deletions

View File

@@ -40,7 +40,7 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 @require_peft_backend
-class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class AuraFlowLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
     pipeline_class = AuraFlowPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_classes = [FlowMatchEulerDiscreteScheduler]
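The base-class swap repeated across all of these test files is what lets the new setUp-based caching in PeftLoraLoaderMixinTests (see the utils hunks below) actually run: Python resolves setUp along the MRO, and with unittest.TestCase listed first, its no-op setUp would shadow the mixin's. A minimal sketch of the mechanics, assuming nothing from diffusers (CachingMixin, GoodOrder, and BadOrder are hypothetical names):

import unittest


class CachingMixin:
    # Stand-in for a test mixin that defines setUp, as PeftLoraLoaderMixinTests now does.
    def setUp(self):
        self.cache_ready = True  # stand-in for _cache_base_pipeline_output()
        super().setUp()  # cooperatively hand off to unittest.TestCase.setUp


class GoodOrder(CachingMixin, unittest.TestCase):
    # Mixin first: CachingMixin.setUp wins the MRO lookup.
    def test_cache_populated(self):
        self.assertTrue(self.cache_ready)


class BadOrder(unittest.TestCase, CachingMixin):
    # TestCase first: its no-op setUp is found first, so the mixin's never runs.
    def test_cache_missing(self):
        self.assertFalse(hasattr(self, "cache_ready"))


if __name__ == "__main__":
    unittest.main()

Both tests pass, which illustrates the failure mode the reordering avoids: without the swap, moving the caching from a pytest fixture into setUp would silently do nothing.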

View File

@@ -40,7 +40,7 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 @require_peft_backend
-class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class CogVideoXLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
     pipeline_class = CogVideoXPipeline
     scheduler_cls = CogVideoXDPMScheduler
     scheduler_kwargs = {"timestep_spacing": "trailing"}

View File

@@ -47,7 +47,7 @@ class TokenizerWrapper:
 @require_peft_backend
 @skip_mps
-class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class CogView4LoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
     pipeline_class = CogView4Pipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_classes = [FlowMatchEulerDiscreteScheduler]

View File

@@ -53,7 +53,7 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
 @require_peft_backend
-class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class FluxLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
     pipeline_class = FluxPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler()
     scheduler_kwargs = {}
@@ -280,7 +280,7 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         pass
 
 
-class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class FluxControlLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
     pipeline_class = FluxControlPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler()
     scheduler_kwargs = {}

View File

@@ -48,7 +48,7 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 @require_peft_backend
 @skip_mps
-class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class HunyuanVideoLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
     pipeline_class = HunyuanVideoPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_classes = [FlowMatchEulerDiscreteScheduler]

View File

@@ -34,7 +34,7 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 @require_peft_backend
-class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class LTXVideoLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
     pipeline_class = LTXPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_classes = [FlowMatchEulerDiscreteScheduler]

View File

@@ -34,7 +34,7 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 @require_peft_backend
 @skip_mps
-class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class MochiLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
     pipeline_class = MochiPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_classes = [FlowMatchEulerDiscreteScheduler]

View File

@@ -34,7 +34,7 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 @require_peft_backend
-class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class QwenImageLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
     pipeline_class = QwenImagePipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_classes = [FlowMatchEulerDiscreteScheduler]

View File

@@ -29,7 +29,7 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 @require_peft_backend
-class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class SanaLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
     pipeline_class = SanaPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler(shift=7.0)
     scheduler_kwargs = {}

View File

@@ -51,7 +51,7 @@ if is_accelerate_available():
 @require_peft_backend
-class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class SD3LoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
     pipeline_class = StableDiffusion3Pipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_kwargs = {}

View File

@@ -39,7 +39,7 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 @require_peft_backend
 @skip_mps
-class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class WanLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
     pipeline_class = WanPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_classes = [FlowMatchEulerDiscreteScheduler]

View File

@@ -47,7 +47,7 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 @require_peft_backend
 @skip_mps
 @is_flaky(max_attempts=10, description="very flaky class")
-class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class WanVACELoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
     pipeline_class = WanVACEPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_classes = [FlowMatchEulerDiscreteScheduler]

View File

@@ -129,17 +129,26 @@ class PeftLoraLoaderMixinTests:
     text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
     denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
-    cached_non_lora_outputs = {}
+    cached_base_pipe_outputs = {}
 
-    @pytest.fixture(scope="class", autouse=True)
-    def cache_non_lora_outputs(self):
-        """
-        This fixture will be executed once per test class and will populate
-        the cached_non_lora_outputs dictionary.
-        """
+    def setUp(self):
+        self._cache_base_pipeline_output()
+        super().setUp()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.cached_base_pipe_outputs.clear()
+
+    def _cache_base_pipeline_output(self):
+        cached_base_pipe_outs = getattr(type(self), "cached_base_pipe_outputs", {})
+        all_scheduler_names = [scheduler_cls.__name__ for scheduler_cls in self.scheduler_classes]
+        if cached_base_pipe_outs is not None and all(k in cached_base_pipe_outs for k in all_scheduler_names):
+            return
+        cached_base_pipe_outs = cached_base_pipe_outs or {}
         for scheduler_cls in self.scheduler_classes:
             # Check if the output for this scheduler is already cached to avoid re-running
-            if scheduler_cls.__name__ in self.cached_non_lora_outputs:
+            if scheduler_cls.__name__ in cached_base_pipe_outs:
                 continue
             components, _, _ = self.get_dummy_components(scheduler_cls)
@@ -151,11 +160,12 @@ class PeftLoraLoaderMixinTests:
             # explicitly.
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            self.cached_non_lora_outputs[scheduler_cls.__name__] = output_no_lora
-
-        # Ensures that there's no inconsistency when reusing the cache.
-        yield
-        self.cached_non_lora_outputs.clear()
+            cached_base_pipe_outs.update({scheduler_cls.__name__: output_no_lora})
+
+        type(self).cached_base_pipe_outputs = cached_base_pipe_outs
+
+    def get_base_pipe_output(self, scheduler_cls):
+        return type(self).cached_base_pipe_outputs[scheduler_cls.__name__]
 
     def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
         if self.unet_kwargs and self.transformer_kwargs:
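Stripped of pipeline plumbing, the pattern these two hunks introduce is: setUp lazily computes one expensive base (non-LoRA) output per scheduler, parks the results on the test class so every test method in the class reuses them, and tearDownClass clears the cache so sibling classes start fresh. A self-contained sketch under those assumptions; expensive_base_output and the cache contents here are illustrative stand-ins, not diffusers APIs:

import unittest


def expensive_base_output(name):
    # Stand-in for one full pipeline forward pass without LoRA.
    return f"output-for-{name}"


class CachedOutputTests(unittest.TestCase):
    scheduler_names = ["FlowMatchEulerDiscreteScheduler", "CogVideoXDPMScheduler"]
    cached_base_pipe_outputs = {}

    def setUp(self):
        cache = type(self).cached_base_pipe_outputs
        for name in self.scheduler_names:
            if name in cache:  # already computed by an earlier test's setUp
                continue
            cache[name] = expensive_base_output(name)
        super().setUp()

    @classmethod
    def tearDownClass(cls):
        # Avoid leaking cached outputs into other classes sharing the mixin.
        cls.cached_base_pipe_outputs.clear()

    def get_base_pipe_output(self, name):
        return type(self).cached_base_pipe_outputs[name]

    def test_reuses_cached_output(self):
        out = self.get_base_pipe_output("CogVideoXDPMScheduler")
        self.assertEqual(out, "output-for-CogVideoXDPMScheduler")


if __name__ == "__main__":
    unittest.main()

One design note on the diff itself: writing the dict back through type(self) rebinds it on the concrete test class, so the cache persists across test methods without the class-scoped pytest fixture (and its yield/clear teardown) that this commit removes.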
@@ -348,7 +358,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple inference and makes sure it works as expected
         """
         for scheduler_cls in self.scheduler_classes:
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipe_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
     def test_simple_inference_with_text_lora(self):
@@ -363,7 +373,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipe_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -446,7 +456,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipe_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -502,7 +512,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipe_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -540,7 +550,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipe_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -572,7 +582,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipe_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -607,7 +617,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipe_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -658,7 +668,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipe_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -709,7 +719,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipe_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -752,7 +762,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipe_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -793,7 +803,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipe_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -837,7 +847,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipe_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -875,7 +885,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipe_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -954,7 +964,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipe_output(scheduler_cls)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1083,7 +1093,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipe_output(scheduler_cls)
 
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1140,7 +1150,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipe_output(scheduler_cls)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1303,7 +1313,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipe_output(scheduler_cls)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1397,7 +1407,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipe_output(scheduler_cls)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1641,7 +1651,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipe_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1722,7 +1732,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipe_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1777,7 +1787,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_dora_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_dora_lora = self.get_base_pipe_output(scheduler_cls)
             self.assertTrue(output_no_dora_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -1909,7 +1919,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            original_out = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            original_out = self.get_base_pipe_output(scheduler_cls)
 
             no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
             logger = logging.get_logger("diffusers.loaders.peft")
@@ -1955,7 +1965,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipe_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -2309,7 +2319,7 @@ class PeftLoraLoaderMixinTests:
             pipe = self.pipeline_class(**components).to(torch_device)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipe_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(
@@ -2359,7 +2369,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipe_output(scheduler_cls)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config)