mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
tempfile is now a fixture.
This commit is contained in:
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -119,7 +118,7 @@ class TestCogView4LoRA(PeftLoraLoaderMixinTests):
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
|
||||
|
||||
def test_simple_inference_save_pretrained(self):
|
||||
def test_simple_inference_save_pretrained(self, tmpdirname):
|
||||
"""
|
||||
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
|
||||
"""
|
||||
@@ -131,11 +130,10 @@ class TestCogView4LoRA(PeftLoraLoaderMixinTests):
|
||||
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pipe.save_pretrained(tmpdirname)
|
||||
pipe.save_pretrained(tmpdirname)
|
||||
|
||||
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
|
||||
pipe_from_pretrained.to(torch_device)
|
||||
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
|
||||
pipe_from_pretrained.to(torch_device)
|
||||
|
||||
images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ import copy
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -114,7 +113,7 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
def test_with_alpha_in_state_dict(self):
|
||||
def test_with_alpha_in_state_dict(self, tmpdirname):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
@@ -126,24 +125,23 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests):
|
||||
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
|
||||
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
|
||||
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
|
||||
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
|
||||
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
# modify the state dict to have alpha values following
|
||||
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors
|
||||
state_dict_with_alpha = safetensors.torch.load_file(
|
||||
os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
|
||||
)
|
||||
alpha_dict = {}
|
||||
for k, v in state_dict_with_alpha.items():
|
||||
if "transformer" in k and "to_k" in k and ("lora_A" in k):
|
||||
alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=()))
|
||||
state_dict_with_alpha.update(alpha_dict)
|
||||
# modify the state dict to have alpha values following
|
||||
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors
|
||||
state_dict_with_alpha = safetensors.torch.load_file(
|
||||
os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
|
||||
)
|
||||
alpha_dict = {}
|
||||
for k, v in state_dict_with_alpha.items():
|
||||
if "transformer" in k and "to_k" in k and ("lora_A" in k):
|
||||
alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=()))
|
||||
state_dict_with_alpha.update(alpha_dict)
|
||||
|
||||
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
|
||||
@@ -156,7 +154,7 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests):
|
||||
)
|
||||
assert not np.allclose(images_lora_with_alpha, images_lora, atol=0.001, rtol=0.001)
|
||||
|
||||
def test_lora_expansion_works_for_absent_keys(self, base_pipe_output):
|
||||
def test_lora_expansion_works_for_absent_keys(self, base_pipe_output, tmpdirname):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
@@ -175,16 +173,15 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests):
|
||||
np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001),
|
||||
"LoRA should lead to different results.",
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
|
||||
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
|
||||
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
|
||||
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
|
||||
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k}
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k}
|
||||
|
||||
pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two")
|
||||
pipe.set_adapters(["one", "two"])
|
||||
@@ -200,7 +197,7 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests):
|
||||
"LoRA should lead to different results.",
|
||||
)
|
||||
|
||||
def test_lora_expansion_works_for_extra_keys(self, base_pipe_output):
|
||||
def test_lora_expansion_works_for_extra_keys(self, base_pipe_output, tmpdirname):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
@@ -217,16 +214,15 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests):
|
||||
np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001),
|
||||
"LoRA should lead to different results.",
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
|
||||
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
|
||||
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k}
|
||||
pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")
|
||||
pipe.unload_lora_weights()
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k}
|
||||
pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")
|
||||
|
||||
pipe.set_adapters(["one", "two"])
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -163,7 +162,7 @@ class TestWanVACELoRA(PeftLoraLoaderMixinTests):
|
||||
super().test_layerwise_casting_inference_denoiser()
|
||||
|
||||
@require_peft_version_greater("0.13.2")
|
||||
def test_lora_exclude_modules_wanvace(self, base_pipe_output):
|
||||
def test_lora_exclude_modules_wanvace(self, base_pipe_output, tmpdirname):
|
||||
exclude_module_name = "vace_blocks.0.proj_out"
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components).to(torch_device)
|
||||
@@ -183,30 +182,26 @@ class TestWanVACELoRA(PeftLoraLoaderMixinTests):
|
||||
assert any("proj_out" in k for k in state_dict_from_model)
|
||||
output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts)
|
||||
pipe.unload_lora_weights()
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
|
||||
pipe.unload_lora_weights()
|
||||
|
||||
# Check in the loaded state dict.
|
||||
loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
assert not any(exclude_module_name in k for k in loaded_state_dict)
|
||||
assert any("proj_out" in k for k in loaded_state_dict)
|
||||
# Check in the loaded state dict.
|
||||
loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
assert not any(exclude_module_name in k for k in loaded_state_dict)
|
||||
assert any("proj_out" in k for k in loaded_state_dict)
|
||||
|
||||
# Check in the state dict obtained after loading LoRA.
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0")
|
||||
assert not any(exclude_module_name in k for k in state_dict_from_model)
|
||||
assert any("proj_out" in k for k in state_dict_from_model)
|
||||
# Check in the state dict obtained after loading LoRA.
|
||||
pipe.load_lora_weights(tmpdirname)
|
||||
state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0")
|
||||
assert not any(exclude_module_name in k for k in state_dict_from_model)
|
||||
assert any("proj_out" in k for k in state_dict_from_model)
|
||||
|
||||
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert not np.allclose(base_pipe_output, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), (
|
||||
"LoRA should change outputs."
|
||||
)
|
||||
assert np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), (
|
||||
"Lora outputs should match."
|
||||
)
|
||||
|
||||
def test_simple_inference_with_text_denoiser_lora_and_scale(self):
|
||||
super().test_simple_inference_with_text_denoiser_lora_and_scale()
|
||||
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert not np.allclose(base_pipe_output, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), (
|
||||
"LoRA should change outputs."
|
||||
)
|
||||
assert np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), (
|
||||
"Lora outputs should match."
|
||||
)
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from itertools import product
|
||||
|
||||
import numpy as np
|
||||
@@ -122,10 +121,18 @@ class PeftLoraLoaderMixinTests:
|
||||
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
|
||||
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def base_pipe_output(self):
|
||||
return self._compute_baseline_output()
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def tmpdirname(self, tmp_path_factory):
|
||||
return tmp_path_factory.mktemp("tmp")
|
||||
|
||||
def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
|
||||
if self.unet_kwargs and self.transformer_kwargs:
|
||||
raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
|
||||
@@ -211,10 +218,6 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
return pipeline_components, text_lora_config, denoiser_lora_config
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_dummy_inputs(self, with_generator=True):
|
||||
batch_size = 1
|
||||
sequence_length = 10
|
||||
@@ -235,6 +238,23 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
return (noise, input_ids, pipeline_inputs)
|
||||
|
||||
def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
|
||||
if text_lora_config is not None:
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
|
||||
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
if denoiser_lora_config is not None:
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name)
|
||||
assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
|
||||
else:
|
||||
denoiser = None
|
||||
if text_lora_config is not None and self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name)
|
||||
assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
return pipe, denoiser
|
||||
|
||||
def _compute_baseline_output(self):
|
||||
components, _, _ = self.get_dummy_components(self.scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
@@ -286,23 +306,6 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
return modules_to_save
|
||||
|
||||
def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
|
||||
if text_lora_config is not None:
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
|
||||
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
if denoiser_lora_config is not None:
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name)
|
||||
assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
|
||||
else:
|
||||
denoiser = None
|
||||
if text_lora_config is not None and self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name)
|
||||
assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
return pipe, denoiser
|
||||
|
||||
def test_simple_inference(self, base_pipe_output):
|
||||
"""
|
||||
Tests a simple inference and makes sure it works as expected
|
||||
@@ -375,7 +378,7 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
@require_peft_version_greater("0.13.1")
|
||||
@require_transformers_version_greater("4.45.2")
|
||||
def test_low_cpu_mem_usage_with_loading(self):
|
||||
def test_low_cpu_mem_usage_with_loading(self, tmpdirname):
|
||||
"""Tests if we can load LoRA state dict with low_cpu_mem_usage."""
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
@@ -386,34 +389,31 @@ class PeftLoraLoaderMixinTests:
|
||||
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(
|
||||
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
|
||||
)
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts)
|
||||
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False)
|
||||
for module_name, module in modules_to_save.items():
|
||||
assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}"
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False)
|
||||
for module_name, module in modules_to_save.items():
|
||||
assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}"
|
||||
|
||||
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), (
|
||||
"Loading from saved checkpoints should give same results."
|
||||
)
|
||||
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), (
|
||||
"Loading from saved checkpoints should give same results."
|
||||
)
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True)
|
||||
for module_name, module in modules_to_save.items():
|
||||
assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}"
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True)
|
||||
for module_name, module in modules_to_save.items():
|
||||
assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}"
|
||||
|
||||
images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert np.allclose(
|
||||
images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=0.001, rtol=0.001
|
||||
), "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results."
|
||||
images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert np.allclose(images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=0.001, rtol=0.001), (
|
||||
"Loading from saved checkpoints with `low_cpu_mem_usage` should give same results."
|
||||
)
|
||||
|
||||
def test_simple_inference_with_text_lora_and_scale(self, base_pipe_output):
|
||||
"""
|
||||
@@ -498,7 +498,7 @@ class PeftLoraLoaderMixinTests:
|
||||
"Fused lora should change the output"
|
||||
)
|
||||
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
def test_simple_inference_with_text_lora_save_load(self, tmpdirname):
|
||||
"""
|
||||
Tests a simple usecase where users could use saving utilities for LoRA.
|
||||
"""
|
||||
@@ -511,16 +511,13 @@ class PeftLoraLoaderMixinTests:
|
||||
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
modules_to_save = self._get_modules_to_save(pipe)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(
|
||||
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
|
||||
)
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
|
||||
modules_to_save = self._get_modules_to_save(pipe)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts)
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
|
||||
|
||||
for module_name, module in modules_to_save.items():
|
||||
assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}"
|
||||
@@ -577,7 +574,7 @@ class PeftLoraLoaderMixinTests:
|
||||
"Removing adapters should change the output"
|
||||
)
|
||||
|
||||
def test_simple_inference_save_pretrained_with_text_lora(self):
|
||||
def test_simple_inference_save_pretrained_with_text_lora(self, tmpdirname):
|
||||
"""
|
||||
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
|
||||
"""
|
||||
@@ -590,10 +587,9 @@ class PeftLoraLoaderMixinTests:
|
||||
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pipe.save_pretrained(tmpdirname)
|
||||
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
|
||||
pipe_from_pretrained.to(torch_device)
|
||||
pipe.save_pretrained(tmpdirname)
|
||||
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
|
||||
pipe_from_pretrained.to(torch_device)
|
||||
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
assert check_if_lora_correctly_set(pipe_from_pretrained.text_encoder), (
|
||||
@@ -610,7 +606,7 @@ class PeftLoraLoaderMixinTests:
|
||||
"Loading from saved checkpoints should give same results."
|
||||
)
|
||||
|
||||
def test_simple_inference_with_text_denoiser_lora_save_load(self):
|
||||
def test_simple_inference_with_text_denoiser_lora_save_load(self, tmpdirname):
|
||||
"""
|
||||
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
|
||||
"""
|
||||
@@ -623,15 +619,12 @@ class PeftLoraLoaderMixinTests:
|
||||
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(
|
||||
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
|
||||
)
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts)
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
|
||||
|
||||
for module_name, module in modules_to_save.items():
|
||||
assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}"
|
||||
@@ -1486,7 +1479,7 @@ class PeftLoraLoaderMixinTests:
|
||||
"DoRA lora should change the output",
|
||||
)
|
||||
|
||||
def test_missing_keys_warning(self):
|
||||
def test_missing_keys_warning(self, tmpdirname):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
@@ -1496,15 +1489,12 @@ class PeftLoraLoaderMixinTests:
|
||||
denoiser.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(
|
||||
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
|
||||
)
|
||||
pipe.unload_lora_weights()
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
|
||||
state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts)
|
||||
pipe.unload_lora_weights()
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
|
||||
state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)
|
||||
|
||||
missing_key = [k for k in state_dict if "lora_A" in k][0]
|
||||
del state_dict[missing_key]
|
||||
@@ -1516,7 +1506,7 @@ class PeftLoraLoaderMixinTests:
|
||||
component = list({k.split(".")[0] for k in state_dict})[0]
|
||||
assert missing_key.replace(f"{component}.", "" in cap_logger.out.replace("default_0.", ""))
|
||||
|
||||
def test_unexpected_keys_warning(self):
|
||||
def test_unexpected_keys_warning(self, tmpdirname):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
@@ -1526,15 +1516,12 @@ class PeftLoraLoaderMixinTests:
|
||||
denoiser.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(
|
||||
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
|
||||
)
|
||||
pipe.unload_lora_weights()
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
|
||||
state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts)
|
||||
pipe.unload_lora_weights()
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
|
||||
state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)
|
||||
|
||||
unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat"
|
||||
state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device)
|
||||
@@ -1616,7 +1603,7 @@ class PeftLoraLoaderMixinTests:
|
||||
)
|
||||
assert cap_logger.out.startswith(f"No LoRA keys associated to {text_encoder.__class__.__name__}")
|
||||
|
||||
def test_set_adapters_match_attention_kwargs(self, base_pipe_output):
|
||||
def test_set_adapters_match_attention_kwargs(self, base_pipe_output, tmpdirname):
|
||||
"""Test to check if outputs after `set_adapters()` and attention kwargs match."""
|
||||
attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
|
||||
@@ -1645,31 +1632,28 @@ class PeftLoraLoaderMixinTests:
|
||||
"Lora + scale should match the output of `set_adapters()`."
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(
|
||||
save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
|
||||
)
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts)
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
for module_name, module in modules_to_save.items():
|
||||
assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}"
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
for module_name, module in modules_to_save.items():
|
||||
assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}"
|
||||
|
||||
output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
|
||||
assert not np.allclose(base_pipe_output, output_lora_from_pretrained, atol=0.001, rtol=0.001), (
|
||||
"Lora + scale should change the output"
|
||||
)
|
||||
assert np.allclose(output_lora_scale, output_lora_from_pretrained, atol=0.001, rtol=0.001), (
|
||||
"Loading from saved checkpoints should give same results as attention_kwargs."
|
||||
)
|
||||
assert np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=0.001, rtol=0.001), (
|
||||
"Loading from saved checkpoints should give same results as set_adapters()."
|
||||
)
|
||||
output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
|
||||
assert not np.allclose(base_pipe_output, output_lora_from_pretrained, atol=0.001, rtol=0.001), (
|
||||
"Lora + scale should change the output"
|
||||
)
|
||||
assert np.allclose(output_lora_scale, output_lora_from_pretrained, atol=0.001, rtol=0.001), (
|
||||
"Loading from saved checkpoints should give same results as attention_kwargs."
|
||||
)
|
||||
assert np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=0.001, rtol=0.001), (
|
||||
"Loading from saved checkpoints should give same results as set_adapters()."
|
||||
)
|
||||
|
||||
@require_peft_version_greater("0.13.2")
|
||||
def test_lora_B_bias(self, base_pipe_output):
|
||||
@@ -1806,7 +1790,7 @@ class PeftLoraLoaderMixinTests:
|
||||
pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
@require_peft_version_greater("0.14.0")
|
||||
def test_layerwise_casting_peft_input_autocast_denoiser(self):
|
||||
def test_layerwise_casting_peft_input_autocast_denoiser(self, tmpdirname):
|
||||
"""
|
||||
A test that checks if layerwise casting works correctly with PEFT layers and forward pass does not fail. This
|
||||
is different from `test_layerwise_casting_inference_denoiser` as that disables the application of layerwise
|
||||
@@ -1865,77 +1849,73 @@ class PeftLoraLoaderMixinTests:
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(
|
||||
save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
|
||||
)
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts)
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
components, _, _ = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device, dtype=compute_dtype)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
apply_layerwise_casting(
|
||||
denoiser,
|
||||
storage_dtype=storage_dtype,
|
||||
compute_dtype=compute_dtype,
|
||||
skip_modules_pattern=patterns_to_check,
|
||||
)
|
||||
check_module(denoiser)
|
||||
components, _, _ = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device, dtype=compute_dtype)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
apply_layerwise_casting(
|
||||
denoiser,
|
||||
storage_dtype=storage_dtype,
|
||||
compute_dtype=compute_dtype,
|
||||
skip_modules_pattern=patterns_to_check,
|
||||
)
|
||||
check_module(denoiser)
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
@parameterized.expand([4, 8, 16])
|
||||
def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha):
|
||||
def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha, tmpdirname):
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe, _ = self.add_adapters_to_pipeline(
|
||||
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
|
||||
pipe.unload_lora_weights()
|
||||
out = pipe.lora_state_dict(tmpdir, return_lora_metadata=True)
|
||||
if len(out) == 3:
|
||||
(_, _, parsed_metadata) = out
|
||||
elif len(out) == 2:
|
||||
(_, parsed_metadata) = out
|
||||
denoiser_key = (
|
||||
f"{self.pipeline_class.transformer_name}"
|
||||
if self.transformer_kwargs is not None
|
||||
else f"{self.pipeline_class.unet_name}"
|
||||
)
|
||||
assert any((k.startswith(f"{denoiser_key}.") for k in parsed_metadata))
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts, **lora_metadatas)
|
||||
pipe.unload_lora_weights()
|
||||
out = pipe.lora_state_dict(tmpdirname, return_lora_metadata=True)
|
||||
if len(out) == 3:
|
||||
(_, _, parsed_metadata) = out
|
||||
elif len(out) == 2:
|
||||
(_, parsed_metadata) = out
|
||||
denoiser_key = (
|
||||
f"{self.pipeline_class.transformer_name}"
|
||||
if self.transformer_kwargs is not None
|
||||
else f"{self.pipeline_class.unet_name}"
|
||||
)
|
||||
assert any((k.startswith(f"{denoiser_key}.") for k in parsed_metadata))
|
||||
|
||||
check_module_lora_metadata(
|
||||
parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key
|
||||
)
|
||||
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
text_encoder_key = self.pipeline_class.text_encoder_name
|
||||
assert any((k.startswith(f"{text_encoder_key}.") for k in parsed_metadata))
|
||||
check_module_lora_metadata(
|
||||
parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key
|
||||
parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key
|
||||
)
|
||||
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
text_encoder_key = self.pipeline_class.text_encoder_name
|
||||
assert any((k.startswith(f"{text_encoder_key}.") for k in parsed_metadata))
|
||||
check_module_lora_metadata(
|
||||
parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key
|
||||
)
|
||||
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
text_encoder_2_key = "text_encoder_2"
|
||||
assert any((k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata))
|
||||
check_module_lora_metadata(
|
||||
parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key
|
||||
)
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
text_encoder_2_key = "text_encoder_2"
|
||||
assert any((k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata))
|
||||
check_module_lora_metadata(
|
||||
parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key
|
||||
)
|
||||
|
||||
@parameterized.expand([4, 8, 16])
|
||||
def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
|
||||
def test_lora_adapter_metadata_save_load_inference(self, lora_alpha, tmpdirname):
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
|
||||
pipe = self.pipeline_class(**components).to(torch_device)
|
||||
|
||||
@@ -1946,18 +1926,15 @@ class PeftLoraLoaderMixinTests:
|
||||
)
|
||||
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
|
||||
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert np.allclose(output_lora, output_lora_pretrained, atol=0.001, rtol=0.001), (
|
||||
"Lora outputs should match."
|
||||
)
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts, **lora_metadatas)
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(tmpdirname)
|
||||
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert np.allclose(output_lora, output_lora_pretrained, atol=0.001, rtol=0.001), "Lora outputs should match."
|
||||
|
||||
def test_lora_unload_add_adapter(self):
|
||||
"""Tests if `unload_lora_weights()` -> `add_adapter()` works."""
|
||||
@@ -1977,7 +1954,7 @@ class PeftLoraLoaderMixinTests:
|
||||
)
|
||||
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
def test_inference_load_delete_load_adapters(self, base_pipe_output):
|
||||
def test_inference_load_delete_load_adapters(self, base_pipe_output, tmpdirname):
|
||||
"""Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."""
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
@@ -2002,22 +1979,21 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
pipe.delete_adapters(pipe.get_active_adapters()[0])
|
||||
output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert not np.allclose(output_adapter_1, output_no_adapter, atol=0.001, rtol=0.001)
|
||||
assert np.allclose(base_pipe_output, output_no_adapter, atol=0.001, rtol=0.001)
|
||||
pipe.delete_adapters(pipe.get_active_adapters()[0])
|
||||
output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert not np.allclose(output_adapter_1, output_no_adapter, atol=0.001, rtol=0.001)
|
||||
assert np.allclose(base_pipe_output, output_no_adapter, atol=0.001, rtol=0.001)
|
||||
|
||||
pipe.load_lora_weights(tmpdirname)
|
||||
output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert np.allclose(output_adapter_1, output_lora_loaded, atol=0.001, rtol=0.001)
|
||||
pipe.load_lora_weights(tmpdirname)
|
||||
output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert np.allclose(output_adapter_1, output_lora_loaded, atol=0.001, rtol=0.001)
|
||||
|
||||
def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
|
||||
def _test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname):
|
||||
from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook
|
||||
|
||||
onload_device = torch_device
|
||||
@@ -2031,59 +2007,56 @@ class PeftLoraLoaderMixinTests:
|
||||
denoiser.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(
|
||||
save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
|
||||
)
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts)
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
components, _, _ = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
check_if_lora_correctly_set(denoiser)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
denoiser.enable_group_offload(
|
||||
onload_device=onload_device,
|
||||
offload_device=offload_device,
|
||||
offload_type=offload_type,
|
||||
num_blocks_per_group=1,
|
||||
use_stream=use_stream,
|
||||
)
|
||||
for _, component in pipe.components.items():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
component.to(torch_device)
|
||||
components, _, _ = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
check_if_lora_correctly_set(denoiser)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
denoiser.enable_group_offload(
|
||||
onload_device=onload_device,
|
||||
offload_device=offload_device,
|
||||
offload_type=offload_type,
|
||||
num_blocks_per_group=1,
|
||||
use_stream=use_stream,
|
||||
)
|
||||
for _, component in pipe.components.items():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
component.to(torch_device)
|
||||
|
||||
group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
|
||||
assert group_offload_hook_1 is not None
|
||||
group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
|
||||
assert group_offload_hook_1 is not None
|
||||
|
||||
output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
pipe.unload_lora_weights()
|
||||
group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser)
|
||||
assert group_offload_hook_2 is not None
|
||||
output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
pipe.unload_lora_weights()
|
||||
group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser)
|
||||
assert group_offload_hook_2 is not None
|
||||
|
||||
output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] # noqa: F841
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
check_if_lora_correctly_set(denoiser)
|
||||
group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser)
|
||||
assert group_offload_hook_3 is not None
|
||||
output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] # noqa: F841
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
check_if_lora_correctly_set(denoiser)
|
||||
group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser)
|
||||
assert group_offload_hook_3 is not None
|
||||
|
||||
output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert np.allclose(output_1, output_3, atol=0.001, rtol=0.001)
|
||||
output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert np.allclose(output_1, output_3, atol=0.001, rtol=0.001)
|
||||
|
||||
@parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)])
|
||||
@require_torch_accelerator
|
||||
def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
|
||||
def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname):
|
||||
for cls in inspect.getmro(self.__class__):
|
||||
if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests:
|
||||
return
|
||||
self._test_group_offloading_inference_denoiser(offload_type, use_stream)
|
||||
self._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_lora_loading_model_cpu_offload(self):
|
||||
def test_lora_loading_model_cpu_offload(self, tmpdirname):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components()
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
pipe = self.pipeline_class(**components)
|
||||
@@ -2096,18 +2069,15 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(
|
||||
save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
|
||||
)
|
||||
components, _, denoiser_lora_config = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.enable_model_cpu_offload(device=torch_device)
|
||||
pipe.load_lora_weights(tmpdirname)
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts)
|
||||
components, _, denoiser_lora_config = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.enable_model_cpu_offload(device=torch_device)
|
||||
pipe.load_lora_weights(tmpdirname)
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
|
||||
|
||||
output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert np.allclose(output_lora, output_lora_loaded, atol=0.001, rtol=0.001)
|
||||
|
||||
Reference in New Issue
Block a user