
[LoRA deprecation] LoRA deprecation trilogy (#6450)

* debug

* debug

* more debug

* more more debug

* remove tests for LoRAAttnProcessors.

* rename
commit 0a0bb526aa
parent 2fada8dc1b
Author: Sayak Paul
Date:   2024-01-05 15:48:20 +05:30
Committed by: GitHub


@@ -1496,7 +1496,7 @@ class UNet2DConditionLoRAModelTests(unittest.TestCase):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
-    def test_lora_processors(self):
+    def test_lora_at_different_scales(self):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1514,9 +1514,6 @@ class UNet2DConditionLoRAModelTests(unittest.TestCase):
         model.load_attn_procs(lora_params)
         model.to(torch_device)
 
-        # test that attn processors can be set to itself
-        model.set_attn_processor(model.attn_processors)
-
         with torch.no_grad():
             sample2 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
             sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
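
For context, a minimal sketch (not part of this diff) of the pattern the renamed test keeps exercising: the same UNet is run with different "scale" values in cross_attention_kwargs, which rescale the LoRA contribution once LoRA layers have been loaded. The tiny config below is an assumption modeled on the dummy configs these tests build; UNet2DConditionModel, attn_processors, and set_attn_processor are the real diffusers API.

# Hedged sketch: hypothetical tiny config, real diffusers API calls.
import torch
from diffusers import UNet2DConditionModel

model = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
)

# The deleted assertion simply round-tripped the processor dict onto the
# model; this is still valid API and leaves an unchanged model as-is.
model.set_attn_processor(model.attn_processors)

sample = torch.randn(1, 4, 32, 32)
encoder_hidden_states = torch.randn(1, 8, 32)

with torch.no_grad():
    # Without LoRA weights loaded, "scale" has nothing to act on and the two
    # outputs match; after load_attn_procs(...) they are expected to diverge.
    out_zero = model(sample, 10, encoder_hidden_states,
                     cross_attention_kwargs={"scale": 0.0}).sample
    out_half = model(sample, 10, encoder_hidden_states,
                     cross_attention_kwargs={"scale": 0.5}).sample
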
@@ -1595,7 +1592,7 @@ class UNet2DConditionLoRAModelTests(unittest.TestCase):
 @deprecate_after_peft_backend
-class UNet3DConditionModelTests(unittest.TestCase):
+class UNet3DConditionLoRAModelTests(unittest.TestCase):
     model_class = UNet3DConditionModel
     main_input_name = "sample"
@@ -1638,7 +1635,7 @@ class UNet3DConditionModelTests(unittest.TestCase):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
-    def test_lora_processors(self):
+    def test_lora_at_different_scales(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
         init_dict["attention_head_dim"] = 8
@@ -1655,9 +1652,6 @@ class UNet3DConditionModelTests(unittest.TestCase):
         model.load_attn_procs(unet_lora_params)
         model.to(torch_device)
 
-        # test that attn processors can be set to itself
-        model.set_attn_processor(model.attn_processors)
-
         with torch.no_grad():
             sample2 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
             sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
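
The 3D hunks mirror the 2D ones; the only wrinkle is the five-dimensional sample of shape (batch, channels, frames, height, width). A minimal sketch of the analogous call, again with an assumed tiny config sized like the dummy configs these tests build:

# Hedged sketch: hypothetical tiny config, real diffusers API calls.
import torch
from diffusers import UNet3DConditionModel

model = UNet3DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
    up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
    cross_attention_dim=32,
    attention_head_dim=8,
)

sample = torch.randn(1, 4, 4, 32, 32)  # (batch, channels, frames, h, w)
encoder_hidden_states = torch.randn(1, 8, 32)

with torch.no_grad():
    out = model(sample, 10, encoder_hidden_states,
                cross_attention_kwargs={"scale": 0.5}).sample
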