From ec866f5de82c3ffafdbdb1bb1e861f5326ddb0a8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 3 Oct 2025 14:25:54 +0530 Subject: [PATCH] tempfile is now a fixture. --- tests/lora/test_lora_layers_cogview4.py | 10 +- tests/lora/test_lora_layers_flux.py | 70 ++-- tests/lora/test_lora_layers_wanvace.py | 47 ++- tests/lora/utils.py | 466 +++++++++++------------- 4 files changed, 276 insertions(+), 317 deletions(-) diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py index de732b8526..3a39c44a37 100644 --- a/tests/lora/test_lora_layers_cogview4.py +++ b/tests/lora/test_lora_layers_cogview4.py @@ -13,7 +13,6 @@ # limitations under the License. import sys -import tempfile import unittest import numpy as np @@ -119,7 +118,7 @@ class TestCogView4LoRA(PeftLoraLoaderMixinTests): def test_simple_inference_with_text_denoiser_lora_unfused(self): super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) - def test_simple_inference_save_pretrained(self): + def test_simple_inference_save_pretrained(self, tmpdirname): """ Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained """ @@ -131,11 +130,10 @@ class TestCogView4LoRA(PeftLoraLoaderMixinTests): images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: - pipe.save_pretrained(tmpdirname) + pipe.save_pretrained(tmpdirname) - pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) - pipe_from_pretrained.to(torch_device) + pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) + pipe_from_pretrained.to(torch_device) images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0] diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index f75a7b3777..7c230308ae 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -16,7 +16,6 @@ import copy import gc import os import sys -import tempfile import unittest import numpy as np @@ -114,7 +113,7 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests): return noise, input_ids, pipeline_inputs - def test_with_alpha_in_state_dict(self): + def test_with_alpha_in_state_dict(self, tmpdirname): components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -126,24 +125,23 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests): images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - with tempfile.TemporaryDirectory() as tmpdirname: - denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) - self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) + denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) + self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - # modify the state dict to have alpha values following - # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors - state_dict_with_alpha = safetensors.torch.load_file( - os.path.join(tmpdirname, "pytorch_lora_weights.safetensors") - ) - alpha_dict = {} - for k, v in state_dict_with_alpha.items(): - if "transformer" in k and "to_k" in k and ("lora_A" in k): - alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=())) - state_dict_with_alpha.update(alpha_dict) + # modify the state dict to have alpha values following + # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors + state_dict_with_alpha = safetensors.torch.load_file( + os.path.join(tmpdirname, "pytorch_lora_weights.safetensors") + ) + alpha_dict = {} + for k, v in state_dict_with_alpha.items(): + if "transformer" in k and "to_k" in k and ("lora_A" in k): + alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=())) + state_dict_with_alpha.update(alpha_dict) images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" @@ -156,7 +154,7 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests): ) assert not np.allclose(images_lora_with_alpha, images_lora, atol=0.001, rtol=0.001) - def test_lora_expansion_works_for_absent_keys(self, base_pipe_output): + def test_lora_expansion_works_for_absent_keys(self, base_pipe_output, tmpdirname): components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -175,16 +173,15 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests): np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), "LoRA should lead to different results.", ) - with tempfile.TemporaryDirectory() as tmpdirname: - denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) - self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) + denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) + self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one") - lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k} + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one") + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k} pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two") pipe.set_adapters(["one", "two"]) @@ -200,7 +197,7 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests): "LoRA should lead to different results.", ) - def test_lora_expansion_works_for_extra_keys(self, base_pipe_output): + def test_lora_expansion_works_for_extra_keys(self, base_pipe_output, tmpdirname): components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -217,16 +214,15 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests): np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), "LoRA should lead to different results.", ) - with tempfile.TemporaryDirectory() as tmpdirname: - denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) - self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) + self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - pipe.unload_lora_weights() - lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k} - pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one") - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two") + pipe.unload_lora_weights() + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k} + pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one") + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two") pipe.set_adapters(["one", "two"]) assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" diff --git a/tests/lora/test_lora_layers_wanvace.py b/tests/lora/test_lora_layers_wanvace.py index 60246ad2bc..9c319b9952 100644 --- a/tests/lora/test_lora_layers_wanvace.py +++ b/tests/lora/test_lora_layers_wanvace.py @@ -14,7 +14,6 @@ import os import sys -import tempfile import unittest import numpy as np @@ -163,7 +162,7 @@ class TestWanVACELoRA(PeftLoraLoaderMixinTests): super().test_layerwise_casting_inference_denoiser() @require_peft_version_greater("0.13.2") - def test_lora_exclude_modules_wanvace(self, base_pipe_output): + def test_lora_exclude_modules_wanvace(self, base_pipe_output, tmpdirname): exclude_module_name = "vace_blocks.0.proj_out" components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components).to(torch_device) @@ -183,30 +182,26 @@ class TestWanVACELoRA(PeftLoraLoaderMixinTests): assert any("proj_out" in k for k in state_dict_from_model) output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdir: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts) - pipe.unload_lora_weights() + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts) + pipe.unload_lora_weights() - # Check in the loaded state dict. - loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) - assert not any(exclude_module_name in k for k in loaded_state_dict) - assert any("proj_out" in k for k in loaded_state_dict) + # Check in the loaded state dict. + loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + assert not any(exclude_module_name in k for k in loaded_state_dict) + assert any("proj_out" in k for k in loaded_state_dict) - # Check in the state dict obtained after loading LoRA. - pipe.load_lora_weights(tmpdir) - state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0") - assert not any(exclude_module_name in k for k in state_dict_from_model) - assert any("proj_out" in k for k in state_dict_from_model) + # Check in the state dict obtained after loading LoRA. + pipe.load_lora_weights(tmpdirname) + state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0") + assert not any(exclude_module_name in k for k in state_dict_from_model) + assert any("proj_out" in k for k in state_dict_from_model) - output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not np.allclose(base_pipe_output, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), ( - "LoRA should change outputs." - ) - assert np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), ( - "Lora outputs should match." - ) - - def test_simple_inference_with_text_denoiser_lora_and_scale(self): - super().test_simple_inference_with_text_denoiser_lora_and_scale() + output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert not np.allclose(base_pipe_output, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), ( + "LoRA should change outputs." + ) + assert np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), ( + "Lora outputs should match." + ) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 0b9e1e0152..bc879f7691 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -15,7 +15,6 @@ import inspect import os import re -import tempfile from itertools import product import numpy as np @@ -122,10 +121,18 @@ class PeftLoraLoaderMixinTests: text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + @property + def output_shape(self): + raise NotImplementedError + @pytest.fixture(scope="class") def base_pipe_output(self): return self._compute_baseline_output() + @pytest.fixture(scope="function") + def tmpdirname(self, tmp_path_factory): + return tmp_path_factory.mktemp("tmp") + def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None): if self.unet_kwargs and self.transformer_kwargs: raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.") @@ -211,10 +218,6 @@ class PeftLoraLoaderMixinTests: return pipeline_components, text_lora_config, denoiser_lora_config - @property - def output_shape(self): - raise NotImplementedError - def get_dummy_inputs(self, with_generator=True): batch_size = 1 sequence_length = 10 @@ -235,6 +238,23 @@ class PeftLoraLoaderMixinTests: return (noise, input_ids, pipeline_inputs) + def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"): + if text_lora_config is not None: + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name) + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + if denoiser_lora_config is not None: + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name) + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." + else: + denoiser = None + if text_lora_config is not None and self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name) + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + return pipe, denoiser + def _compute_baseline_output(self): components, _, _ = self.get_dummy_components(self.scheduler_cls) pipe = self.pipeline_class(**components) @@ -286,23 +306,6 @@ class PeftLoraLoaderMixinTests: return modules_to_save - def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"): - if text_lora_config is not None: - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name) - assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - if denoiser_lora_config is not None: - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name) - assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." - else: - denoiser = None - if text_lora_config is not None and self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name) - assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - return pipe, denoiser - def test_simple_inference(self, base_pipe_output): """ Tests a simple inference and makes sure it works as expected @@ -375,7 +378,7 @@ class PeftLoraLoaderMixinTests: @require_peft_version_greater("0.13.1") @require_transformers_version_greater("4.45.2") - def test_low_cpu_mem_usage_with_loading(self): + def test_low_cpu_mem_usage_with_loading(self, tmpdirname): """Tests if we can load LoRA state dict with low_cpu_mem_usage.""" components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -386,34 +389,31 @@ class PeftLoraLoaderMixinTests: pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts - ) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts) - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False) - for module_name, module in modules_to_save.items(): - assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False) + for module_name, module in modules_to_save.items(): + assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" - images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), ( - "Loading from saved checkpoints should give same results." - ) + images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), ( + "Loading from saved checkpoints should give same results." + ) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True) - for module_name, module in modules_to_save.items(): - assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True) + for module_name, module in modules_to_save.items(): + assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" - images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose( - images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=0.001, rtol=0.001 - ), "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results." + images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert np.allclose(images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=0.001, rtol=0.001), ( + "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results." + ) def test_simple_inference_with_text_lora_and_scale(self, base_pipe_output): """ @@ -498,7 +498,7 @@ class PeftLoraLoaderMixinTests: "Fused lora should change the output" ) - def test_simple_inference_with_text_lora_save_load(self): + def test_simple_inference_with_text_lora_save_load(self, tmpdirname): """ Tests a simple usecase where users could use saving utilities for LoRA. """ @@ -511,16 +511,13 @@ class PeftLoraLoaderMixinTests: pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts - ) - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + modules_to_save = self._get_modules_to_save(pipe) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) for module_name, module in modules_to_save.items(): assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" @@ -577,7 +574,7 @@ class PeftLoraLoaderMixinTests: "Removing adapters should change the output" ) - def test_simple_inference_save_pretrained_with_text_lora(self): + def test_simple_inference_save_pretrained_with_text_lora(self, tmpdirname): """ Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained """ @@ -590,10 +587,9 @@ class PeftLoraLoaderMixinTests: pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: - pipe.save_pretrained(tmpdirname) - pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) - pipe_from_pretrained.to(torch_device) + pipe.save_pretrained(tmpdirname) + pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) + pipe_from_pretrained.to(torch_device) if "text_encoder" in self.pipeline_class._lora_loadable_modules: assert check_if_lora_correctly_set(pipe_from_pretrained.text_encoder), ( @@ -610,7 +606,7 @@ class PeftLoraLoaderMixinTests: "Loading from saved checkpoints should give same results." ) - def test_simple_inference_with_text_denoiser_lora_save_load(self): + def test_simple_inference_with_text_denoiser_lora_save_load(self, tmpdirname): """ Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder """ @@ -623,15 +619,12 @@ class PeftLoraLoaderMixinTests: pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts - ) - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) for module_name, module in modules_to_save.items(): assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" @@ -1486,7 +1479,7 @@ class PeftLoraLoaderMixinTests: "DoRA lora should change the output", ) - def test_missing_keys_warning(self): + def test_missing_keys_warning(self, tmpdirname): components, _, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -1496,15 +1489,12 @@ class PeftLoraLoaderMixinTests: denoiser.add_adapter(denoiser_lora_config) assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts - ) - pipe.unload_lora_weights() - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts) + pipe.unload_lora_weights() + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True) missing_key = [k for k in state_dict if "lora_A" in k][0] del state_dict[missing_key] @@ -1516,7 +1506,7 @@ class PeftLoraLoaderMixinTests: component = list({k.split(".")[0] for k in state_dict})[0] assert missing_key.replace(f"{component}.", "" in cap_logger.out.replace("default_0.", "")) - def test_unexpected_keys_warning(self): + def test_unexpected_keys_warning(self, tmpdirname): components, _, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -1526,15 +1516,12 @@ class PeftLoraLoaderMixinTests: denoiser.add_adapter(denoiser_lora_config) assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts - ) - pipe.unload_lora_weights() - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts) + pipe.unload_lora_weights() + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True) unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat" state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device) @@ -1616,7 +1603,7 @@ class PeftLoraLoaderMixinTests: ) assert cap_logger.out.startswith(f"No LoRA keys associated to {text_encoder.__class__.__name__}") - def test_set_adapters_match_attention_kwargs(self, base_pipe_output): + def test_set_adapters_match_attention_kwargs(self, base_pipe_output, tmpdirname): """Test to check if outputs after `set_adapters()` and attention kwargs match.""" attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) components, text_lora_config, denoiser_lora_config = self.get_dummy_components() @@ -1645,31 +1632,28 @@ class PeftLoraLoaderMixinTests: "Lora + scale should match the output of `set_adapters()`." ) - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts - ) - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - for module_name, module in modules_to_save.items(): - assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + for module_name, module in modules_to_save.items(): + assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" - output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - assert not np.allclose(base_pipe_output, output_lora_from_pretrained, atol=0.001, rtol=0.001), ( - "Lora + scale should change the output" - ) - assert np.allclose(output_lora_scale, output_lora_from_pretrained, atol=0.001, rtol=0.001), ( - "Loading from saved checkpoints should give same results as attention_kwargs." - ) - assert np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=0.001, rtol=0.001), ( - "Loading from saved checkpoints should give same results as set_adapters()." - ) + output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] + assert not np.allclose(base_pipe_output, output_lora_from_pretrained, atol=0.001, rtol=0.001), ( + "Lora + scale should change the output" + ) + assert np.allclose(output_lora_scale, output_lora_from_pretrained, atol=0.001, rtol=0.001), ( + "Loading from saved checkpoints should give same results as attention_kwargs." + ) + assert np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=0.001, rtol=0.001), ( + "Loading from saved checkpoints should give same results as set_adapters()." + ) @require_peft_version_greater("0.13.2") def test_lora_B_bias(self, base_pipe_output): @@ -1806,7 +1790,7 @@ class PeftLoraLoaderMixinTests: pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0] @require_peft_version_greater("0.14.0") - def test_layerwise_casting_peft_input_autocast_denoiser(self): + def test_layerwise_casting_peft_input_autocast_denoiser(self, tmpdirname): """ A test that checks if layerwise casting works correctly with PEFT layers and forward pass does not fail. This is different from `test_layerwise_casting_inference_denoiser` as that disables the application of layerwise @@ -1865,77 +1849,73 @@ class PeftLoraLoaderMixinTests: _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts - ) - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - components, _, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device, dtype=compute_dtype) - pipe.set_progress_bar_config(disable=None) - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - apply_layerwise_casting( - denoiser, - storage_dtype=storage_dtype, - compute_dtype=compute_dtype, - skip_modules_pattern=patterns_to_check, - ) - check_module(denoiser) + components, _, _ = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device, dtype=compute_dtype) + pipe.set_progress_bar_config(disable=None) + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + apply_layerwise_casting( + denoiser, + storage_dtype=storage_dtype, + compute_dtype=compute_dtype, + skip_modules_pattern=patterns_to_check, + ) + check_module(denoiser) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe(**inputs, generator=torch.manual_seed(0))[0] + _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe(**inputs, generator=torch.manual_seed(0))[0] @parameterized.expand([4, 8, 16]) - def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha): + def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha, tmpdirname): components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha) pipe = self.pipeline_class(**components) pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config ) - with tempfile.TemporaryDirectory() as tmpdir: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) - self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) - pipe.unload_lora_weights() - out = pipe.lora_state_dict(tmpdir, return_lora_metadata=True) - if len(out) == 3: - (_, _, parsed_metadata) = out - elif len(out) == 2: - (_, parsed_metadata) = out - denoiser_key = ( - f"{self.pipeline_class.transformer_name}" - if self.transformer_kwargs is not None - else f"{self.pipeline_class.unet_name}" - ) - assert any((k.startswith(f"{denoiser_key}.") for k in parsed_metadata)) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts, **lora_metadatas) + pipe.unload_lora_weights() + out = pipe.lora_state_dict(tmpdirname, return_lora_metadata=True) + if len(out) == 3: + (_, _, parsed_metadata) = out + elif len(out) == 2: + (_, parsed_metadata) = out + denoiser_key = ( + f"{self.pipeline_class.transformer_name}" + if self.transformer_kwargs is not None + else f"{self.pipeline_class.unet_name}" + ) + assert any((k.startswith(f"{denoiser_key}.") for k in parsed_metadata)) + check_module_lora_metadata( + parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key + ) + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + text_encoder_key = self.pipeline_class.text_encoder_name + assert any((k.startswith(f"{text_encoder_key}.") for k in parsed_metadata)) check_module_lora_metadata( - parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key + parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key ) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - text_encoder_key = self.pipeline_class.text_encoder_name - assert any((k.startswith(f"{text_encoder_key}.") for k in parsed_metadata)) - check_module_lora_metadata( - parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key - ) - - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - text_encoder_2_key = "text_encoder_2" - assert any((k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata)) - check_module_lora_metadata( - parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + text_encoder_2_key = "text_encoder_2" + assert any((k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata)) + check_module_lora_metadata( + parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key + ) @parameterized.expand([4, 8, 16]) - def test_lora_adapter_metadata_save_load_inference(self, lora_alpha): + def test_lora_adapter_metadata_save_load_inference(self, lora_alpha, tmpdirname): components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha) pipe = self.pipeline_class(**components).to(torch_device) @@ -1946,18 +1926,15 @@ class PeftLoraLoaderMixinTests: ) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdir: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) - self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) - pipe.unload_lora_weights() - pipe.load_lora_weights(tmpdir) - output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(output_lora, output_lora_pretrained, atol=0.001, rtol=0.001), ( - "Lora outputs should match." - ) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts, **lora_metadatas) + pipe.unload_lora_weights() + pipe.load_lora_weights(tmpdirname) + output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert np.allclose(output_lora, output_lora_pretrained, atol=0.001, rtol=0.001), "Lora outputs should match." def test_lora_unload_add_adapter(self): """Tests if `unload_lora_weights()` -> `add_adapter()` works.""" @@ -1977,7 +1954,7 @@ class PeftLoraLoaderMixinTests: ) _ = pipe(**inputs, generator=torch.manual_seed(0))[0] - def test_inference_load_delete_load_adapters(self, base_pipe_output): + def test_inference_load_delete_load_adapters(self, base_pipe_output, tmpdirname): """Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works.""" components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -2002,22 +1979,21 @@ class PeftLoraLoaderMixinTests: output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts) - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - pipe.delete_adapters(pipe.get_active_adapters()[0]) - output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert not np.allclose(output_adapter_1, output_no_adapter, atol=0.001, rtol=0.001) - assert np.allclose(base_pipe_output, output_no_adapter, atol=0.001, rtol=0.001) + pipe.delete_adapters(pipe.get_active_adapters()[0]) + output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert not np.allclose(output_adapter_1, output_no_adapter, atol=0.001, rtol=0.001) + assert np.allclose(base_pipe_output, output_no_adapter, atol=0.001, rtol=0.001) - pipe.load_lora_weights(tmpdirname) - output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(output_adapter_1, output_lora_loaded, atol=0.001, rtol=0.001) + pipe.load_lora_weights(tmpdirname) + output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert np.allclose(output_adapter_1, output_lora_loaded, atol=0.001, rtol=0.001) - def _test_group_offloading_inference_denoiser(self, offload_type, use_stream): + def _test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname): from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook onload_device = torch_device @@ -2031,59 +2007,56 @@ class PeftLoraLoaderMixinTests: denoiser.add_adapter(denoiser_lora_config) assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts - ) - assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - components, _, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.set_progress_bar_config(disable=None) - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - check_if_lora_correctly_set(denoiser) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - denoiser.enable_group_offload( - onload_device=onload_device, - offload_device=offload_device, - offload_type=offload_type, - num_blocks_per_group=1, - use_stream=use_stream, - ) - for _, component in pipe.components.items(): - if isinstance(component, torch.nn.Module): - component.to(torch_device) + components, _, _ = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + check_if_lora_correctly_set(denoiser) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + denoiser.enable_group_offload( + onload_device=onload_device, + offload_device=offload_device, + offload_type=offload_type, + num_blocks_per_group=1, + use_stream=use_stream, + ) + for _, component in pipe.components.items(): + if isinstance(component, torch.nn.Module): + component.to(torch_device) - group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser) - assert group_offload_hook_1 is not None + group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser) + assert group_offload_hook_1 is not None - output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - pipe.unload_lora_weights() - group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser) - assert group_offload_hook_2 is not None + output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] + pipe.unload_lora_weights() + group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser) + assert group_offload_hook_2 is not None - output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] # noqa: F841 - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - check_if_lora_correctly_set(denoiser) - group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser) - assert group_offload_hook_3 is not None + output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] # noqa: F841 + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + check_if_lora_correctly_set(denoiser) + group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser) + assert group_offload_hook_3 is not None - output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0] - assert np.allclose(output_1, output_3, atol=0.001, rtol=0.001) + output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert np.allclose(output_1, output_3, atol=0.001, rtol=0.001) @parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)]) @require_torch_accelerator - def test_group_offloading_inference_denoiser(self, offload_type, use_stream): + def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname): for cls in inspect.getmro(self.__class__): if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests: return - self._test_group_offloading_inference_denoiser(offload_type, use_stream) + self._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname) @require_torch_accelerator - def test_lora_loading_model_cpu_offload(self): + def test_lora_loading_model_cpu_offload(self, tmpdirname): components, _, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe = self.pipeline_class(**components) @@ -2096,18 +2069,15 @@ class PeftLoraLoaderMixinTests: output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts - ) - components, _, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.enable_model_cpu_offload(device=torch_device) - pipe.load_lora_weights(tmpdirname) - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts) + components, _, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.enable_model_cpu_offload(device=torch_device) + pipe.load_lora_weights(tmpdirname) + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0] assert np.allclose(output_lora, output_lora_loaded, atol=0.001, rtol=0.001)