Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
more fixtures.
@@ -14,7 +14,6 @@
 import sys
 
 import numpy as np
 import pytest
 import torch
 from transformers import AutoTokenizer, GlmModel
@@ -26,7 +25,6 @@ from ..testing_utils import (
     require_peft_backend,
     require_torch_accelerator,
     skip_mps,
     torch_device,
 )
@@ -117,29 +115,6 @@ class TestCogView4LoRA(PeftLoraLoaderMixinTests):
     def test_simple_inference_with_text_denoiser_lora_unfused(self):
         super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
 
-    def test_simple_inference_save_pretrained(self, tmpdirname):
-        """
-        Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
-        """
-        components, _, _ = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-        images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-        pipe.save_pretrained(tmpdirname)
-
-        pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
-        pipe_from_pretrained.to(torch_device)
-
-        images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
-
-        assert np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3), (
-            "Loading from saved checkpoints should give same results."
-        )
-
     @pytest.mark.parametrize(
         "offload_type, use_stream",
         [("block_level", True), ("leaf_level", False), ("leaf_level", True)],
@@ -155,11 +155,6 @@ class TestHunyuanVideoLoRA(PeftLoraLoaderMixinTests):
     def test_simple_inference_with_text_denoiser_lora_unfused(self):
         super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
 
-    # TODO(aryan): Fix the following test
-    @pytest.mark.skip("This test fails with an error I haven't been able to debug yet.")
-    def test_simple_inference_save_pretrained(self):
-        pass
-
     @pytest.mark.skip("Not supported in HunyuanVideo.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
@@ -132,6 +132,14 @@ class PeftLoraLoaderMixinTests:
     def tmpdirname(self, tmp_path_factory):
         return tmp_path_factory.mktemp("tmp")
 
+    @pytest.fixture(scope="function")
+    def pipe(self):
+        components, _, _ = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        return pipe
+
     def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
         if self.unet_kwargs and self.transformer_kwargs:
             raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
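The hunk above is the heart of the commit: per-test pipeline construction moves into a function-scoped pytest fixture, and every test that names `pipe` in its signature receives a fresh instance. A minimal self-contained sketch of the mechanism, with a hypothetical `DummyPipeline` standing in for a real diffusers pipeline:

    import pytest

    class DummyPipeline:
        # Hypothetical stand-in for a diffusers pipeline; illustration only.
        def to(self, device):
            self.device = device
            return self

    @pytest.fixture(scope="function")
    def pipe():
        # scope="function" builds a brand-new pipeline per test, so one test's
        # LoRA adapters cannot leak into the next.
        return DummyPipeline().to("cpu")

    def test_receives_fresh_pipe(pipe):
        # pytest matches the parameter name "pipe" to the fixture and injects it.
        assert pipe.device == "cpu"

The `tmpdirname` fixture just above it works the same way: `tmp_path_factory` is pytest's built-in temporary-directory fixture, which lets the converted tests drop their manual `tempfile` bookkeeping.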
@@ -314,16 +322,12 @@ class PeftLoraLoaderMixinTests:
         """
         assert base_pipe_output.shape == self.output_shape
 
-    def test_simple_inference_with_text_lora(self, base_pipe_output):
+    def test_simple_inference_with_text_lora(self, base_pipe_output, pipe):
         """
         Tests a simple inference with lora attached on the text encoder
         and makes sure it works as expected
         """
-        components, text_lora_config, _ = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
+        _, text_lora_config, _ = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
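This conversion pattern repeats through the rest of the diff: the four construction lines disappear, `get_dummy_components()` is kept only for the LoRA configs (the components tuple is discarded into `_`), and `pipe` arrives via the fixture. For orientation, a sketch of the shape a helper like `add_adapters_to_pipeline` plausibly has, inferred from its call sites in this diff rather than taken from the actual implementation:

    def add_adapters_to_pipeline(pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
        # Inferred sketch only: attach LoRA adapters to the text encoder and/or
        # the denoiser (UNet or transformer); returns (pipe, denoiser or None).
        denoiser = None
        if text_lora_config is not None:
            pipe.text_encoder.add_adapter(text_lora_config, adapter_name)
        if denoiser_lora_config is not None:
            denoiser = pipe.unet if getattr(pipe, "unet", None) is not None else pipe.transformer
            denoiser.add_adapter(denoiser_lora_config, adapter_name)
        return pipe, denoiser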
@@ -331,12 +335,9 @@ class PeftLoraLoaderMixinTests:
         assert not np.allclose(output_lora, base_pipe_output, atol=1e-3, rtol=1e-3), "Lora should change the output"
 
     @require_peft_version_greater("0.13.1")
-    def test_low_cpu_mem_usage_with_injection(self):
+    def test_low_cpu_mem_usage_with_injection(self, pipe):
        """Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
 
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True)
@@ -380,13 +381,9 @@ class PeftLoraLoaderMixinTests:
 
     @require_peft_version_greater("0.13.1")
     @require_transformers_version_greater("4.45.2")
-    def test_low_cpu_mem_usage_with_loading(self, tmpdirname):
+    def test_low_cpu_mem_usage_with_loading(self, tmpdirname, pipe):
         """Tests if we can load LoRA state dict with low_cpu_mem_usage."""
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
+        _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -417,17 +414,13 @@ class PeftLoraLoaderMixinTests:
             "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results."
         )
 
-    def test_simple_inference_with_text_lora_and_scale(self, base_pipe_output):
+    def test_simple_inference_with_text_lora_and_scale(self, base_pipe_output, pipe):
         """
         Tests a simple inference with lora attached on the text encoder + scale argument
         and makes sure it works as expected
         """
         attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
-        components, text_lora_config, _ = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
+        _, text_lora_config, _ = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -446,16 +439,12 @@ class PeftLoraLoaderMixinTests:
             "Lora + 0 scale should lead to same result as no LoRA"
         )
 
-    def test_simple_inference_with_text_lora_fused(self, base_pipe_output):
+    def test_simple_inference_with_text_lora_fused(self, base_pipe_output, pipe):
         """
         Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
         and makes sure it works as expected
         """
-        components, text_lora_config, _ = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
+        _, text_lora_config, _ = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -471,16 +460,12 @@ class PeftLoraLoaderMixinTests:
             "Fused lora should change the output"
         )
 
-    def test_simple_inference_with_text_lora_unloaded(self, base_pipe_output):
+    def test_simple_inference_with_text_lora_unloaded(self, base_pipe_output, pipe):
         """
         Tests a simple inference with lora attached to text encoder, then unloads the lora weights
         and makes sure it works as expected
         """
-        components, text_lora_config, _ = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
+        _, text_lora_config, _ = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -498,14 +483,11 @@ class PeftLoraLoaderMixinTests:
             "Unloading lora should match the base pipe output"
         )
 
-    def test_simple_inference_with_text_lora_save_load(self, tmpdirname):
+    def test_simple_inference_with_text_lora_save_load(self, tmpdirname, pipe):
         """
         Tests a simple usecase where users could use saving utilities for LoRA.
         """
-        components, text_lora_config, _ = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        _, text_lora_config, _ = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -527,13 +509,12 @@ class PeftLoraLoaderMixinTests:
             "Loading from saved checkpoints should give same results."
         )
 
-    def test_simple_inference_with_partial_text_lora(self, base_pipe_output):
+    def test_simple_inference_with_partial_text_lora(self, base_pipe_output, pipe):
         """
         Tests a simple inference with lora attached on the text encoder
         with different ranks and some adapters removed
         and makes sure it works as expected
         """
-        components, _, _ = self.get_dummy_components()
         text_lora_config = LoraConfig(
             r=4,
             rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)},
@@ -542,9 +523,6 @@ class PeftLoraLoaderMixinTests:
             init_lora_weights=False,
             use_dora=False,
         )
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
 
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -574,15 +552,11 @@ class PeftLoraLoaderMixinTests:
             "Removing adapters should change the output"
         )
 
-    def test_simple_inference_save_pretrained_with_text_lora(self, tmpdirname):
+    def test_simple_inference_save_pretrained_with_text_lora(self, tmpdirname, pipe):
         """
         Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
         """
-        components, text_lora_config, _ = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
+        _, text_lora_config, _ = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -606,15 +580,11 @@ class PeftLoraLoaderMixinTests:
             "Loading from saved checkpoints should give same results."
         )
 
-    def test_simple_inference_with_text_denoiser_lora_save_load(self, tmpdirname):
+    def test_simple_inference_with_text_denoiser_lora_save_load(self, tmpdirname, pipe):
         """
         Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
         """
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
+        _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -634,17 +604,13 @@ class PeftLoraLoaderMixinTests:
             "Loading from saved checkpoints should give same results."
         )
 
-    def test_simple_inference_with_text_denoiser_lora_and_scale(self, base_pipe_output):
+    def test_simple_inference_with_text_denoiser_lora_and_scale(self, base_pipe_output, pipe):
         """
         Tests a simple inference with lora attached on the text encoder + Unet + scale argument
         and makes sure it works as expected
         """
         attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
+        _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -667,16 +633,12 @@ class PeftLoraLoaderMixinTests:
             "The scaling parameter has not been correctly restored!"
         )
 
-    def test_simple_inference_with_text_lora_denoiser_fused(self, base_pipe_output):
+    def test_simple_inference_with_text_lora_denoiser_fused(self, base_pipe_output, pipe):
         """
         Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
         and makes sure it works as expected - with unet
         """
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
+        _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
         pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -694,16 +656,12 @@ class PeftLoraLoaderMixinTests:
             "Fused lora should change the output"
         )
 
-    def test_simple_inference_with_text_denoiser_lora_unloaded(self, base_pipe_output):
+    def test_simple_inference_with_text_denoiser_lora_unloaded(self, base_pipe_output, pipe):
         """
         Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
         and makes sure it works as expected
         """
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
+        _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
         pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -723,17 +681,13 @@ class PeftLoraLoaderMixinTests:
         )
 
     def test_simple_inference_with_text_denoiser_lora_unfused(
-        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
+        self, pipe, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
     ):
         """
         Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
         and makes sure it works as expected
         """
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
+        _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
         pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
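Note how `pipe` coexists with the defaulted `expected_atol`/`expected_rtol` parameters in the new signature: pytest only resolves parameters without default values as fixtures, so the tolerance arguments stay ordinary keyword parameters that subclasses can override. A standalone illustration of that behavior (not from the suite):

    import pytest

    @pytest.fixture
    def pipe():
        return "fresh-pipeline"  # placeholder object for illustration

    def test_mixed_signature(pipe, expected_atol: float = 1e-3):
        # Only "pipe" is fixture-resolved; parameters with defaults are left alone.
        assert pipe == "fresh-pipeline"
        assert expected_atol == 1e-3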
@@ -840,15 +794,10 @@ class PeftLoraLoaderMixinTests:
         pipe.set_adapters(adapter_name)
         _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
-    def test_multiple_wrong_adapter_name_raises_error(self):
+    def test_multiple_wrong_adapter_name_raises_error(self, pipe):
         adapter_name = "adapter-1"
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
+        _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
         pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
         )
@@ -864,16 +813,12 @@ class PeftLoraLoaderMixinTests:
         pipe.set_adapters(adapter_name)
         _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
-    def test_simple_inference_with_text_denoiser_block_scale(self, base_pipe_output):
+    def test_simple_inference_with_text_denoiser_block_scale(self, base_pipe_output, pipe):
         """
         Tests a simple inference with lora attached to text encoder and unet, attaches
         one adapter and set different weights for different blocks (i.e. block lora)
         """
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
+        _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
         pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -911,17 +856,14 @@ class PeftLoraLoaderMixinTests:
             "output with no lora and output with lora disabled should give same results"
         )
 
-    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self, base_pipe_output):
+    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self, base_pipe_output, pipe):
         """
         Tests a simple inference with lora attached to text encoder and unet, attaches
         multiple adapters and set different weights for different blocks (i.e. block lora)
         """
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
@@ -967,7 +909,7 @@ class PeftLoraLoaderMixinTests:
         with pytest.raises(ValueError):
             pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1])
 
-    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self, pipe):
         """Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""
 
         def updown_options(blocks_with_tf, layers_per_block, value):
@@ -1024,10 +966,7 @@ class PeftLoraLoaderMixinTests:
                     opts.append(opt)
             return opts
 
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_cls)
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        _, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_cls)
 
         pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1045,16 +984,12 @@ class PeftLoraLoaderMixinTests:
             del scale_dict["text_encoder_2"]
         pipe.set_adapters("adapter-1", scale_dict)  # test will fail if this line throws an error
 
-    def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self, base_pipe_output):
+    def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self, base_pipe_output, pipe):
         """
         Tests a simple inference with lora attached to text encoder and unet, attaches
         multiple adapters and set/delete them
         """
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
+        _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1121,16 +1056,12 @@ class PeftLoraLoaderMixinTests:
             "output with no lora and output with lora disabled should give same results"
         )
 
-    def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self, base_pipe_output):
+    def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self, base_pipe_output, pipe):
         """
         Tests a simple inference with lora attached to text encoder and unet, attaches
         multiple adapters and set them
         """
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
+        _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1186,12 +1117,8 @@ class PeftLoraLoaderMixinTests:
         reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
         strict=False,
     )
-    def test_lora_fuse_nan(self):
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
+    def test_lora_fuse_nan(self, pipe):
+        _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1235,17 +1162,12 @@ class PeftLoraLoaderMixinTests:
         out = pipe(**inputs)[0]
         assert np.isnan(out).all()
 
-    def test_get_adapters(self):
+    def test_get_adapters(self, pipe):
         """
         Tests a simple usecase where we attach multiple adapters and check if the results
         are the expected results
         """
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
+        _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
 
         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
@@ -1263,15 +1185,12 @@ class PeftLoraLoaderMixinTests:
         pipe.set_adapters(["adapter-1", "adapter-2"])
         assert sorted(pipe.get_active_adapters()) == ["adapter-1", "adapter-2"]
 
-    def test_get_list_adapters(self):
+    def test_get_list_adapters(self, pipe):
         """
         Tests a simple usecase where we attach multiple adapters and check if the results
         are the expected results
         """
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
 
         # 1.
         dicts_to_be_checked = {}
@@ -1324,16 +1243,16 @@ class PeftLoraLoaderMixinTests:
         assert pipe.get_list_adapters() == dicts_to_be_checked
 
     def test_simple_inference_with_text_lora_denoiser_fused_multi(
-        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
+        self,
+        pipe,
+        expected_atol: float = 1e-3,
+        expected_rtol: float = 1e-3,
     ):
         """
         Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
         and makes sure it works as expected - with unet and multi-adapter case
         """
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1443,12 +1362,8 @@ class PeftLoraLoaderMixinTests:
             "LoRA should change the output"
         )
 
-    def test_simple_inference_with_dora(self):
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True)
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
+    def test_simple_inference_with_dora(self, pipe):
+        _, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
         output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1460,12 +1375,8 @@ class PeftLoraLoaderMixinTests:
             "DoRA lora should change the output"
         )
 
-    def test_missing_keys_warning(self, tmpdirname):
-        components, _, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
+    def test_missing_keys_warning(self, tmpdirname, pipe):
+        _, _, denoiser_lora_config = self.get_dummy_components()
         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
         denoiser.add_adapter(denoiser_lora_config)
         assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
@@ -1487,11 +1398,8 @@ class PeftLoraLoaderMixinTests:
         component = list({k.split(".")[0] for k in state_dict})[0]
         assert missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", "")
 
-    def test_unexpected_keys_warning(self, tmpdirname):
-        components, _, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+    def test_unexpected_keys_warning(self, tmpdirname, pipe):
+        _, _, denoiser_lora_config = self.get_dummy_components()
 
         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
         denoiser.add_adapter(denoiser_lora_config)
@@ -1513,16 +1421,12 @@ class PeftLoraLoaderMixinTests:
         assert ".diffusers_cat" in cap_logger.out
 
     @pytest.mark.skip("This is failing for now - need to investigate")
-    def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
+    def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self, pipe):
         """
         Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
         and makes sure it works as expected
         """
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
+        _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -1533,29 +1437,19 @@ class PeftLoraLoaderMixinTests:
             pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)
         _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
-    def test_modify_padding_mode(self):
+    def test_modify_padding_mode(self, pipe):
         def set_pad_mode(network, mode="circular"):
             for _, module in network.named_modules():
                 if isinstance(module, torch.nn.Conv2d):
                     module.padding_mode = mode
 
-        components, _, _ = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
         _pad_mode = "circular"
         set_pad_mode(pipe.vae, _pad_mode)
         set_pad_mode(pipe.unet, _pad_mode)
         _, _, inputs = self.get_dummy_inputs()
         _ = pipe(**inputs)[0]
 
-    def test_logs_info_when_no_lora_keys_found(self, base_pipe_output):
-        components, _, _ = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
+    def test_logs_info_when_no_lora_keys_found(self, base_pipe_output, pipe):
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
         no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
@@ -1584,16 +1478,11 @@ class PeftLoraLoaderMixinTests:
         )
         assert cap_logger.out.startswith(f"No LoRA keys associated to {text_encoder.__class__.__name__}")
 
-    def test_set_adapters_match_attention_kwargs(self, base_pipe_output, tmpdirname):
+    def test_set_adapters_match_attention_kwargs(self, base_pipe_output, tmpdirname, pipe):
         """Test to check if outputs after `set_adapters()` and attention kwargs match."""
         attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         lora_scale = 0.5
@@ -1636,12 +1525,8 @@ class PeftLoraLoaderMixinTests:
         )
 
     @require_peft_version_greater("0.13.2")
-    def test_lora_B_bias(self, base_pipe_output):
-        components, _, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
+    def test_lora_B_bias(self, base_pipe_output, pipe):
+        _, _, denoiser_lora_config = self.get_dummy_components()
         bias_values = {}
         denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
         for name, module in denoiser.named_modules():
@@ -1670,12 +1555,8 @@ class PeftLoraLoaderMixinTests:
         assert not np.allclose(base_pipe_output, lora_bias_true_output, atol=1e-3, rtol=1e-3)
         assert not np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3)
 
-    def test_correct_lora_configs_with_different_ranks(self, base_pipe_output):
-        components, _, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
+    def test_correct_lora_configs_with_different_ranks(self, base_pipe_output, pipe):
+        _, _, denoiser_lora_config = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
         if self.unet_kwargs is not None:
@@ -1852,9 +1733,8 @@ class PeftLoraLoaderMixinTests:
             pipe(**inputs, generator=torch.manual_seed(0))[0]
 
     @pytest.mark.parametrize("lora_alpha", [4, 8, 16])
-    def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha, tmpdirname):
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
-        pipe = self.pipeline_class(**components)
+    def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha, tmpdirname, pipe):
+        _, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
         pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
         )
@@ -1895,10 +1775,8 @@ class PeftLoraLoaderMixinTests:
         )
 
     @pytest.mark.parametrize("lora_alpha", [4, 8, 16])
-    def test_lora_adapter_metadata_save_load_inference(self, lora_alpha, tmpdirname):
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
-        pipe = self.pipeline_class(**components).to(torch_device)
-
+    def test_lora_adapter_metadata_save_load_inference(self, lora_alpha, tmpdirname, pipe):
+        _, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
         pipe, _ = self.add_adapters_to_pipeline(
@@ -1916,11 +1794,9 @@ class PeftLoraLoaderMixinTests:
         output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
         assert np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
 
-    def test_lora_unload_add_adapter(self):
+    def test_lora_unload_add_adapter(self, pipe):
         """Tests if `unload_lora_weights()` -> `add_adapter()` works."""
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components).to(torch_device)
-
+        _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
         pipe, _ = self.add_adapters_to_pipeline(
@@ -1934,13 +1810,9 @@ class PeftLoraLoaderMixinTests:
         )
         _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
-    def test_inference_load_delete_load_adapters(self, base_pipe_output, tmpdirname):
+    def test_inference_load_delete_load_adapters(self, base_pipe_output, tmpdirname, pipe):
         """Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."""
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
+        _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1973,15 +1845,12 @@ class PeftLoraLoaderMixinTests:
         output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
         assert np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3)
 
-    def _test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname):
+    def _test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe):
         from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook
 
         onload_device = torch_device
         offload_device = torch.device("cpu")
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        _, _, denoiser_lora_config = self.get_dummy_components()
 
         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
         denoiser.add_adapter(denoiser_lora_config)
@@ -1995,6 +1864,7 @@ class PeftLoraLoaderMixinTests:
         components, _, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
+        pipe.set_progress_bar_config(disable=None)
 
         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
         pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
         check_if_lora_correctly_set(denoiser)
@@ -2032,19 +1902,16 @@ class PeftLoraLoaderMixinTests:
         [("block_level", True), ("leaf_level", False), ("leaf_level", True)],
     )
     @require_torch_accelerator
-    def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname):
+    def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe):
         for cls in inspect.getmro(self.__class__):
             if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests:
                 return
-        self._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname)
+        self._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname, pipe)
 
     @pytest.mark.skipif(torch_device == "cpu", reason="test requires accelerator+PyTorch")
-    def test_lora_loading_model_cpu_offload(self, tmpdirname):
-        components, _, denoiser_lora_config = self.get_dummy_components()
+    def test_lora_loading_model_cpu_offload(self, tmpdirname, pipe):
+        _, _, denoiser_lora_config = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
 
         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
         denoiser.add_adapter(denoiser_lora_config)
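The `inspect.getmro` loop above is a guard: the parametrized base-class test returns early whenever a subclass defines its own `test_group_offloading_inference_denoiser`, so the scenario is not exercised twice. A standalone illustration of the idiom (class and method names hypothetical):

    import inspect

    class Base:
        def run(self):
            # Skip the base implementation when any subclass overrides "hook".
            for cls in inspect.getmro(self.__class__):
                if "hook" in cls.__dict__ and cls is not Base:
                    return "skipped"
            return self.hook()

        def hook(self):
            return "base"

    class Child(Base):
        def hook(self):
            return "child"

    assert Base().run() == "base"
    assert Child().run() == "skipped"  # the subclass override trips the guard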
@@ -2055,6 +1922,7 @@ class PeftLoraLoaderMixinTests:
         modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
         lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
         self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts)
 
         components, _, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe.enable_model_cpu_offload(device=torch_device)
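Finally, a detail running through every assertion in this diff: outputs are compared by re-running the pipeline with `generator=torch.manual_seed(0)`, so both calls consume identical noise and `np.allclose` with tight tolerances is meaningful. A small runnable demonstration of the pattern (the `fake_denoise` function is a stand-in, not suite code):

    import numpy as np
    import torch

    def fake_denoise(generator):
        # Stand-in for a pipeline call that consumes randomness from the generator.
        return torch.randn(4, generator=generator).numpy()

    out_a = fake_denoise(torch.manual_seed(0))
    out_b = fake_denoise(torch.manual_seed(0))
    assert np.allclose(out_a, out_b, atol=1e-3, rtol=1e-3)  # same seed, same output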