mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

tempfile is now a fixture.

sayakpaul
2025-10-03 14:25:54 +05:30
parent 7b4bcce602
commit ec866f5de8
4 changed files with 276 additions and 317 deletions
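The pattern applied across all four files is the same: each test used to open its own scratch directory with `with tempfile.TemporaryDirectory() as tmpdirname:`, and the new code instead takes a function-scoped pytest fixture (`tmpdirname`, built on `tmp_path_factory`) as a test argument, so pytest creates and manages the directory. Below is a minimal, self-contained sketch of that pattern under assumed names; `SimpleModel` and `test_save_then_load` are illustrative stand-ins, and only the fixture wiring mirrors the diff.

import json

import pytest


class SimpleModel:
    """Hypothetical stand-in for a diffusers pipeline with save/load helpers."""

    def save_pretrained(self, directory):
        # tmp_path_factory.mktemp() returns a pathlib.Path, so "/" joins paths.
        (directory / "config.json").write_text(json.dumps({"ok": True}))

    @classmethod
    def from_pretrained(cls, directory):
        assert json.loads((directory / "config.json").read_text())["ok"]
        return cls()


class TestSaveLoad:
    # Same shape as the fixture added in this commit: one fresh directory per
    # test function, created via pytest's tmp_path_factory and managed by pytest.
    @pytest.fixture(scope="function")
    def tmpdirname(self, tmp_path_factory):
        return tmp_path_factory.mktemp("tmp")

    # Before: the body opened `with tempfile.TemporaryDirectory() as tmpdirname:`.
    # After: the directory simply arrives as an argument.
    def test_save_then_load(self, tmpdirname):
        model = SimpleModel()
        model.save_pretrained(tmpdirname)
        assert isinstance(SimpleModel.from_pretrained(tmpdirname), SimpleModel)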


@@ -13,7 +13,6 @@
# limitations under the License.
import sys
-import tempfile
import unittest
import numpy as np
@@ -119,7 +118,7 @@ class TestCogView4LoRA(PeftLoraLoaderMixinTests):
def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
-    def test_simple_inference_save_pretrained(self):
+    def test_simple_inference_save_pretrained(self, tmpdirname):
"""
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
"""
@@ -131,11 +130,10 @@ class TestCogView4LoRA(PeftLoraLoaderMixinTests):
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            pipe.save_pretrained(tmpdirname)
+        pipe.save_pretrained(tmpdirname)
-            pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
-            pipe_from_pretrained.to(torch_device)
+        pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
+        pipe_from_pretrained.to(torch_device)
images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]


@@ -16,7 +16,6 @@ import copy
import gc
import os
import sys
-import tempfile
import unittest
import numpy as np
@@ -114,7 +113,7 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
-    def test_with_alpha_in_state_dict(self):
+    def test_with_alpha_in_state_dict(self, tmpdirname):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
@@ -126,24 +125,23 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests):
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
-            self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
+        denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
+        self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
-            assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
-            pipe.unload_lora_weights()
-            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+        assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+        pipe.unload_lora_weights()
+        pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
-            # modify the state dict to have alpha values following
-            # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors
-            state_dict_with_alpha = safetensors.torch.load_file(
-                os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
-            )
-            alpha_dict = {}
-            for k, v in state_dict_with_alpha.items():
-                if "transformer" in k and "to_k" in k and ("lora_A" in k):
-                    alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=()))
-            state_dict_with_alpha.update(alpha_dict)
+        # modify the state dict to have alpha values following
+        # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors
+        state_dict_with_alpha = safetensors.torch.load_file(
+            os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
+        )
+        alpha_dict = {}
+        for k, v in state_dict_with_alpha.items():
+            if "transformer" in k and "to_k" in k and ("lora_A" in k):
+                alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=()))
+        state_dict_with_alpha.update(alpha_dict)
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
@@ -156,7 +154,7 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests):
)
assert not np.allclose(images_lora_with_alpha, images_lora, atol=0.001, rtol=0.001)
-    def test_lora_expansion_works_for_absent_keys(self, base_pipe_output):
+    def test_lora_expansion_works_for_absent_keys(self, base_pipe_output, tmpdirname):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
@@ -175,16 +173,15 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests):
np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001),
"LoRA should lead to different results.",
)
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
-            self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
+        denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
+        self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
-            assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+        assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
-            pipe.unload_lora_weights()
-            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
-            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
-            lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k}
+        pipe.unload_lora_weights()
+        pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
+        lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+        lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k}
pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two")
pipe.set_adapters(["one", "two"])
@@ -200,7 +197,7 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests):
"LoRA should lead to different results.",
)
-    def test_lora_expansion_works_for_extra_keys(self, base_pipe_output):
+    def test_lora_expansion_works_for_extra_keys(self, base_pipe_output, tmpdirname):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
@@ -217,16 +214,15 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests):
np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001),
"LoRA should lead to different results.",
)
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
-            self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
-            assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+        denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
+        self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
+        assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
-            pipe.unload_lora_weights()
-            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
-            lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k}
-            pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")
-            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")
+        pipe.unload_lora_weights()
+        lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+        lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k}
+        pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")
+        pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")
pipe.set_adapters(["one", "two"])
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"


@@ -14,7 +14,6 @@
import os
import sys
-import tempfile
import unittest
import numpy as np
@@ -163,7 +162,7 @@ class TestWanVACELoRA(PeftLoraLoaderMixinTests):
super().test_layerwise_casting_inference_denoiser()
@require_peft_version_greater("0.13.2")
-    def test_lora_exclude_modules_wanvace(self, base_pipe_output):
+    def test_lora_exclude_modules_wanvace(self, base_pipe_output, tmpdirname):
exclude_module_name = "vace_blocks.0.proj_out"
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device)
@@ -183,30 +182,26 @@ class TestWanVACELoRA(PeftLoraLoaderMixinTests):
assert any("proj_out" in k for k in state_dict_from_model)
output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        with tempfile.TemporaryDirectory() as tmpdir:
-            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts)
-            pipe.unload_lora_weights()
+        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+        lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+        self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
+        pipe.unload_lora_weights()
-            # Check in the loaded state dict.
-            loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
-            assert not any(exclude_module_name in k for k in loaded_state_dict)
-            assert any("proj_out" in k for k in loaded_state_dict)
+        # Check in the loaded state dict.
+        loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+        assert not any(exclude_module_name in k for k in loaded_state_dict)
+        assert any("proj_out" in k for k in loaded_state_dict)
-            # Check in the state dict obtained after loading LoRA.
-            pipe.load_lora_weights(tmpdir)
-            state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0")
-            assert not any(exclude_module_name in k for k in state_dict_from_model)
-            assert any("proj_out" in k for k in state_dict_from_model)
+        # Check in the state dict obtained after loading LoRA.
+        pipe.load_lora_weights(tmpdirname)
+        state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0")
+        assert not any(exclude_module_name in k for k in state_dict_from_model)
+        assert any("proj_out" in k for k in state_dict_from_model)
-            output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            assert not np.allclose(base_pipe_output, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), (
-                "LoRA should change outputs."
-            )
-            assert np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), (
-                "Lora outputs should match."
-            )
+        output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        assert not np.allclose(base_pipe_output, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), (
+            "LoRA should change outputs."
+        )
+        assert np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), (
+            "Lora outputs should match."
+        )

    def test_simple_inference_with_text_denoiser_lora_and_scale(self):
        super().test_simple_inference_with_text_denoiser_lora_and_scale()


@@ -15,7 +15,6 @@
import inspect
import os
import re
-import tempfile
from itertools import product
import numpy as np
@@ -122,10 +121,18 @@ class PeftLoraLoaderMixinTests:
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
+    @property
+    def output_shape(self):
+        raise NotImplementedError
@pytest.fixture(scope="class")
def base_pipe_output(self):
return self._compute_baseline_output()
+    @pytest.fixture(scope="function")
+    def tmpdirname(self, tmp_path_factory):
+        return tmp_path_factory.mktemp("tmp")
def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
if self.unet_kwargs and self.transformer_kwargs:
raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
@@ -211,10 +218,6 @@ class PeftLoraLoaderMixinTests:
return pipeline_components, text_lora_config, denoiser_lora_config
-    @property
-    def output_shape(self):
-        raise NotImplementedError
def get_dummy_inputs(self, with_generator=True):
batch_size = 1
sequence_length = 10
@@ -235,6 +238,23 @@ class PeftLoraLoaderMixinTests:
return (noise, input_ids, pipeline_inputs)
+    def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
+        if text_lora_config is not None:
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
+                assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+        if denoiser_lora_config is not None:
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+            denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name)
+            assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
+        else:
+            denoiser = None
+        if text_lora_config is not None and self.has_two_text_encoders or self.has_three_text_encoders:
+            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name)
+                assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+        return pipe, denoiser
def _compute_baseline_output(self):
components, _, _ = self.get_dummy_components(self.scheduler_cls)
pipe = self.pipeline_class(**components)
@@ -286,23 +306,6 @@ class PeftLoraLoaderMixinTests:
return modules_to_save
-    def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
-        if text_lora_config is not None:
-            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-                pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
-                assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-        if denoiser_lora_config is not None:
-            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-            denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name)
-            assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
-        else:
-            denoiser = None
-        if text_lora_config is not None and self.has_two_text_encoders or self.has_three_text_encoders:
-            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name)
-                assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-        return pipe, denoiser
def test_simple_inference(self, base_pipe_output):
"""
Tests a simple inference and makes sure it works as expected
@@ -375,7 +378,7 @@ class PeftLoraLoaderMixinTests:
@require_peft_version_greater("0.13.1")
@require_transformers_version_greater("4.45.2")
-    def test_low_cpu_mem_usage_with_loading(self):
+    def test_low_cpu_mem_usage_with_loading(self, tmpdirname):
"""Tests if we can load LoRA state dict with low_cpu_mem_usage."""
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -386,34 +389,31 @@ class PeftLoraLoaderMixinTests:
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-            self.pipeline_class.save_lora_weights(
-                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
-            )
+        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+        lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+        self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts)
-            assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
+        assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
-            pipe.unload_lora_weights()
-            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False)
-            for module_name, module in modules_to_save.items():
-                assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}"
+        pipe.unload_lora_weights()
+        pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False)
+        for module_name, module in modules_to_save.items():
+            assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}"
-            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), (
-                "Loading from saved checkpoints should give same results."
-            )
+        images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), (
+            "Loading from saved checkpoints should give same results."
+        )
-            pipe.unload_lora_weights()
-            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True)
-            for module_name, module in modules_to_save.items():
-                assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}"
+        pipe.unload_lora_weights()
+        pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True)
+        for module_name, module in modules_to_save.items():
+            assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}"
-            images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            assert np.allclose(
-                images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=0.001, rtol=0.001
-            ), "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results."
+        images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        assert np.allclose(images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=0.001, rtol=0.001), (
+            "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results."
+        )
def test_simple_inference_with_text_lora_and_scale(self, base_pipe_output):
"""
@@ -498,7 +498,7 @@ class PeftLoraLoaderMixinTests:
"Fused lora should change the output"
)
-    def test_simple_inference_with_text_lora_save_load(self):
+    def test_simple_inference_with_text_lora_save_load(self, tmpdirname):
"""
Tests a simple usecase where users could use saving utilities for LoRA.
"""
@@ -511,16 +511,13 @@ class PeftLoraLoaderMixinTests:
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            modules_to_save = self._get_modules_to_save(pipe)
-            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-            self.pipeline_class.save_lora_weights(
-                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
-            )
-            assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
+        modules_to_save = self._get_modules_to_save(pipe)
+        lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+        self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts)
+        assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
-            pipe.unload_lora_weights()
-            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
+        pipe.unload_lora_weights()
+        pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
for module_name, module in modules_to_save.items():
assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}"
@@ -577,7 +574,7 @@ class PeftLoraLoaderMixinTests:
"Removing adapters should change the output"
)
-    def test_simple_inference_save_pretrained_with_text_lora(self):
+    def test_simple_inference_save_pretrained_with_text_lora(self, tmpdirname):
"""
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
"""
@@ -590,10 +587,9 @@ class PeftLoraLoaderMixinTests:
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            pipe.save_pretrained(tmpdirname)
-            pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
-            pipe_from_pretrained.to(torch_device)
+        pipe.save_pretrained(tmpdirname)
+        pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
+        pipe_from_pretrained.to(torch_device)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
assert check_if_lora_correctly_set(pipe_from_pretrained.text_encoder), (
@@ -610,7 +606,7 @@ class PeftLoraLoaderMixinTests:
"Loading from saved checkpoints should give same results."
)
-    def test_simple_inference_with_text_denoiser_lora_save_load(self):
+    def test_simple_inference_with_text_denoiser_lora_save_load(self, tmpdirname):
"""
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
"""
@@ -623,15 +619,12 @@ class PeftLoraLoaderMixinTests:
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-            self.pipeline_class.save_lora_weights(
-                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
-            )
-            assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
-            pipe.unload_lora_weights()
-            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
+        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+        lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+        self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts)
+        assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
+        pipe.unload_lora_weights()
+        pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
for module_name, module in modules_to_save.items():
assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}"
@@ -1486,7 +1479,7 @@ class PeftLoraLoaderMixinTests:
"DoRA lora should change the output",
)
-    def test_missing_keys_warning(self):
+    def test_missing_keys_warning(self, tmpdirname):
components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
@@ -1496,15 +1489,12 @@ class PeftLoraLoaderMixinTests:
denoiser.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-            self.pipeline_class.save_lora_weights(
-                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
-            )
-            pipe.unload_lora_weights()
-            assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
-            state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)
+        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+        lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+        self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts)
+        pipe.unload_lora_weights()
+        assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
+        state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)
missing_key = [k for k in state_dict if "lora_A" in k][0]
del state_dict[missing_key]
@@ -1516,7 +1506,7 @@ class PeftLoraLoaderMixinTests:
component = list({k.split(".")[0] for k in state_dict})[0]
assert missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", "")
-    def test_unexpected_keys_warning(self):
+    def test_unexpected_keys_warning(self, tmpdirname):
components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
@@ -1526,15 +1516,12 @@ class PeftLoraLoaderMixinTests:
denoiser.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-            self.pipeline_class.save_lora_weights(
-                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
-            )
-            pipe.unload_lora_weights()
-            assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
-            state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)
+        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+        lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+        self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts)
+        pipe.unload_lora_weights()
+        assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
+        state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)
unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat"
state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device)
@@ -1616,7 +1603,7 @@ class PeftLoraLoaderMixinTests:
)
assert cap_logger.out.startswith(f"No LoRA keys associated to {text_encoder.__class__.__name__}")
-    def test_set_adapters_match_attention_kwargs(self, base_pipe_output):
+    def test_set_adapters_match_attention_kwargs(self, base_pipe_output, tmpdirname):
"""Test to check if outputs after `set_adapters()` and attention kwargs match."""
attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
@@ -1645,31 +1632,28 @@ class PeftLoraLoaderMixinTests:
"Lora + scale should match the output of `set_adapters()`."
)
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-            self.pipeline_class.save_lora_weights(
-                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
-            )
-            assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+        lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+        self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts)
+        assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
-            pipe = self.pipeline_class(**components)
-            pipe = pipe.to(torch_device)
-            pipe.set_progress_bar_config(disable=None)
-            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
-            for module_name, module in modules_to_save.items():
-                assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}"
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+        for module_name, module in modules_to_save.items():
+            assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}"
-            output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
-            assert not np.allclose(base_pipe_output, output_lora_from_pretrained, atol=0.001, rtol=0.001), (
-                "Lora + scale should change the output"
-            )
-            assert np.allclose(output_lora_scale, output_lora_from_pretrained, atol=0.001, rtol=0.001), (
-                "Loading from saved checkpoints should give same results as attention_kwargs."
-            )
-            assert np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=0.001, rtol=0.001), (
-                "Loading from saved checkpoints should give same results as set_adapters()."
-            )
+        output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+        assert not np.allclose(base_pipe_output, output_lora_from_pretrained, atol=0.001, rtol=0.001), (
+            "Lora + scale should change the output"
+        )
+        assert np.allclose(output_lora_scale, output_lora_from_pretrained, atol=0.001, rtol=0.001), (
+            "Loading from saved checkpoints should give same results as attention_kwargs."
+        )
+        assert np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=0.001, rtol=0.001), (
+            "Loading from saved checkpoints should give same results as set_adapters()."
+        )
@require_peft_version_greater("0.13.2")
def test_lora_B_bias(self, base_pipe_output):
@@ -1806,7 +1790,7 @@ class PeftLoraLoaderMixinTests:
pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0]
@require_peft_version_greater("0.14.0")
-    def test_layerwise_casting_peft_input_autocast_denoiser(self):
+    def test_layerwise_casting_peft_input_autocast_denoiser(self, tmpdirname):
"""
A test that checks if layerwise casting works correctly with PEFT layers and forward pass does not fail. This
is different from `test_layerwise_casting_inference_denoiser` as that disables the application of layerwise
@@ -1865,77 +1849,73 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe(**inputs, generator=torch.manual_seed(0))[0]
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-            self.pipeline_class.save_lora_weights(
-                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
-            )
-            assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+        lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+        self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts)
+        assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
-            components, _, _ = self.get_dummy_components()
-            pipe = self.pipeline_class(**components)
-            pipe = pipe.to(torch_device, dtype=compute_dtype)
-            pipe.set_progress_bar_config(disable=None)
-            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
-            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-            apply_layerwise_casting(
-                denoiser,
-                storage_dtype=storage_dtype,
-                compute_dtype=compute_dtype,
-                skip_modules_pattern=patterns_to_check,
-            )
-            check_module(denoiser)
+        components, _, _ = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device, dtype=compute_dtype)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        apply_layerwise_casting(
+            denoiser,
+            storage_dtype=storage_dtype,
+            compute_dtype=compute_dtype,
+            skip_modules_pattern=patterns_to_check,
+        )
+        check_module(denoiser)
-            _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            pipe(**inputs, generator=torch.manual_seed(0))[0]
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        pipe(**inputs, generator=torch.manual_seed(0))[0]
@parameterized.expand([4, 8, 16])
-    def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha):
+    def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha, tmpdirname):
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
pipe = self.pipeline_class(**components)
pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
-        with tempfile.TemporaryDirectory() as tmpdir:
-            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-            lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
-            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
-            pipe.unload_lora_weights()
-            out = pipe.lora_state_dict(tmpdir, return_lora_metadata=True)
-            if len(out) == 3:
-                (_, _, parsed_metadata) = out
-            elif len(out) == 2:
-                (_, parsed_metadata) = out
-            denoiser_key = (
-                f"{self.pipeline_class.transformer_name}"
-                if self.transformer_kwargs is not None
-                else f"{self.pipeline_class.unet_name}"
-            )
-            assert any((k.startswith(f"{denoiser_key}.") for k in parsed_metadata))
+        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+        lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+        lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
+        self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts, **lora_metadatas)
+        pipe.unload_lora_weights()
+        out = pipe.lora_state_dict(tmpdirname, return_lora_metadata=True)
+        if len(out) == 3:
+            (_, _, parsed_metadata) = out
+        elif len(out) == 2:
+            (_, parsed_metadata) = out
+        denoiser_key = (
+            f"{self.pipeline_class.transformer_name}"
+            if self.transformer_kwargs is not None
+            else f"{self.pipeline_class.unet_name}"
+        )
+        assert any((k.startswith(f"{denoiser_key}.") for k in parsed_metadata))
-            check_module_lora_metadata(
-                parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key
-            )
-            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-                text_encoder_key = self.pipeline_class.text_encoder_name
-                assert any((k.startswith(f"{text_encoder_key}.") for k in parsed_metadata))
-                check_module_lora_metadata(
-                    parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key
-                )
-            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                text_encoder_2_key = "text_encoder_2"
-                assert any((k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata))
-                check_module_lora_metadata(
-                    parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key
-                )
+        check_module_lora_metadata(
+            parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key
+        )
+        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+            text_encoder_key = self.pipeline_class.text_encoder_name
+            assert any((k.startswith(f"{text_encoder_key}.") for k in parsed_metadata))
+            check_module_lora_metadata(
+                parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key
+            )
+        if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+            text_encoder_2_key = "text_encoder_2"
+            assert any((k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata))
+            check_module_lora_metadata(
+                parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key
+            )
@parameterized.expand([4, 8, 16])
-    def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
+    def test_lora_adapter_metadata_save_load_inference(self, lora_alpha, tmpdirname):
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
pipe = self.pipeline_class(**components).to(torch_device)
@@ -1946,18 +1926,15 @@ class PeftLoraLoaderMixinTests:
)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        with tempfile.TemporaryDirectory() as tmpdir:
-            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-            lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
+        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+        lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+        lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
-            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
-            pipe.unload_lora_weights()
-            pipe.load_lora_weights(tmpdir)
-            output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            assert np.allclose(output_lora, output_lora_pretrained, atol=0.001, rtol=0.001), (
-                "Lora outputs should match."
-            )
+        self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts, **lora_metadatas)
+        pipe.unload_lora_weights()
+        pipe.load_lora_weights(tmpdirname)
+        output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        assert np.allclose(output_lora, output_lora_pretrained, atol=0.001, rtol=0.001), "Lora outputs should match."
def test_lora_unload_add_adapter(self):
"""Tests if `unload_lora_weights()` -> `add_adapter()` works."""
@@ -1977,7 +1954,7 @@ class PeftLoraLoaderMixinTests:
)
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
-    def test_inference_load_delete_load_adapters(self, base_pipe_output):
+    def test_inference_load_delete_load_adapters(self, base_pipe_output, tmpdirname):
"""Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."""
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -2002,22 +1979,21 @@ class PeftLoraLoaderMixinTests:
output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-            self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
-            assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+        lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+        self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
+        assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
-            pipe.delete_adapters(pipe.get_active_adapters()[0])
-            output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            assert not np.allclose(output_adapter_1, output_no_adapter, atol=0.001, rtol=0.001)
-            assert np.allclose(base_pipe_output, output_no_adapter, atol=0.001, rtol=0.001)
+        pipe.delete_adapters(pipe.get_active_adapters()[0])
+        output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        assert not np.allclose(output_adapter_1, output_no_adapter, atol=0.001, rtol=0.001)
+        assert np.allclose(base_pipe_output, output_no_adapter, atol=0.001, rtol=0.001)
-            pipe.load_lora_weights(tmpdirname)
-            output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            assert np.allclose(output_adapter_1, output_lora_loaded, atol=0.001, rtol=0.001)
+        pipe.load_lora_weights(tmpdirname)
+        output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        assert np.allclose(output_adapter_1, output_lora_loaded, atol=0.001, rtol=0.001)
-    def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+    def _test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname):
from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook
onload_device = torch_device
@@ -2031,59 +2007,56 @@ class PeftLoraLoaderMixinTests:
denoiser.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-            self.pipeline_class.save_lora_weights(
-                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
-            )
-            assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+        lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+        self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts)
+        assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
-            components, _, _ = self.get_dummy_components()
-            pipe = self.pipeline_class(**components)
-            pipe.set_progress_bar_config(disable=None)
-            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
-            check_if_lora_correctly_set(denoiser)
-            _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            denoiser.enable_group_offload(
-                onload_device=onload_device,
-                offload_device=offload_device,
-                offload_type=offload_type,
-                num_blocks_per_group=1,
-                use_stream=use_stream,
-            )
-            for _, component in pipe.components.items():
-                if isinstance(component, torch.nn.Module):
-                    component.to(torch_device)
+        components, _, _ = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.set_progress_bar_config(disable=None)
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+        check_if_lora_correctly_set(denoiser)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        denoiser.enable_group_offload(
+            onload_device=onload_device,
+            offload_device=offload_device,
+            offload_type=offload_type,
+            num_blocks_per_group=1,
+            use_stream=use_stream,
+        )
+        for _, component in pipe.components.items():
+            if isinstance(component, torch.nn.Module):
+                component.to(torch_device)
-            group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
-            assert group_offload_hook_1 is not None
+        group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
+        assert group_offload_hook_1 is not None
-            output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            pipe.unload_lora_weights()
-            group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser)
-            assert group_offload_hook_2 is not None
+        output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        pipe.unload_lora_weights()
+        group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser)
+        assert group_offload_hook_2 is not None
-            output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]  # noqa: F841
-            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
-            check_if_lora_correctly_set(denoiser)
-            group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser)
-            assert group_offload_hook_3 is not None
+        output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]  # noqa: F841
+        pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+        check_if_lora_correctly_set(denoiser)
+        group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser)
+        assert group_offload_hook_3 is not None
-            output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            assert np.allclose(output_1, output_3, atol=0.001, rtol=0.001)
+        output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        assert np.allclose(output_1, output_3, atol=0.001, rtol=0.001)
@parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)])
@require_torch_accelerator
-    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+    def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname):
for cls in inspect.getmro(self.__class__):
if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests:
return
-        self._test_group_offloading_inference_denoiser(offload_type, use_stream)
+        self._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname)
@require_torch_accelerator
-    def test_lora_loading_model_cpu_offload(self):
+    def test_lora_loading_model_cpu_offload(self, tmpdirname):
components, _, denoiser_lora_config = self.get_dummy_components()
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe = self.pipeline_class(**components)
@@ -2096,18 +2069,15 @@ class PeftLoraLoaderMixinTests:
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-            self.pipeline_class.save_lora_weights(
-                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
-            )
-            components, _, denoiser_lora_config = self.get_dummy_components()
-            pipe = self.pipeline_class(**components)
-            pipe.enable_model_cpu_offload(device=torch_device)
-            pipe.load_lora_weights(tmpdirname)
-            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-            assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
+        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+        lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+        self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts)
+        components, _, denoiser_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.enable_model_cpu_offload(device=torch_device)
+        pipe.load_lora_weights(tmpdirname)
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
assert np.allclose(output_lora, output_lora_loaded, atol=0.001, rtol=0.001)
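For context on the other fixture visible in these hunks: `base_pipe_output` is declared with scope="class", so the comparatively expensive baseline pipeline run is computed once per test class and reused by every test that takes it as an argument. A small sketch of that caching behaviour under assumed names (`expensive_baseline` stands in for `_compute_baseline_output()`; nothing here is taken from the diffusers API beyond the fixture scope):

import pytest


def expensive_baseline():
    """Hypothetical stand-in for _compute_baseline_output(); pretend it is slow."""
    expensive_baseline.calls += 1
    return [0.0, 1.0, 2.0]


expensive_baseline.calls = 0


class TestUsesBaseline:
    # scope="class": the fixture body runs once for the whole class, and its
    # cached return value is handed to every test that requests `base_pipe_output`.
    @pytest.fixture(scope="class")
    def base_pipe_output(self):
        return expensive_baseline()

    def test_first(self, base_pipe_output):
        assert base_pipe_output == [0.0, 1.0, 2.0]

    def test_second(self, base_pipe_output):
        # Not recomputed for the second test in the class.
        assert expensive_baseline.calls == 1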