Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[lora] feat: pass `exclude_modules` to `LoraConfig`. (#11806)

* feat: use exclude modules to loraconfig.
* version-guard.
* tests and version guard.
* remove print.
* describe the test
* more detailed warning message + shift to debug
* update
* update
* update
* remove test
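For orientation, a minimal sketch (not part of this commit) of what the change enables: when a LoRA checkpoint only covers part of a model, e.g. the FusionX LoRA loaded into a Wan VACE transformer whose extra blocks the LoRA never touches, the untouched modules can be passed to peft's `LoraConfig` via `exclude_modules` (supported since peft 0.14.0). The module names below are hypothetical.

from peft import LoraConfig

# Hypothetical setup: the checkpoint only targets attention projections in `blocks`,
# while the model also contains `vace_blocks` the checkpoint never touches.
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["blocks.0.attn.to_q", "blocks.0.attn.to_k"],
    exclude_modules=["vace_blocks.0.proj"],  # requires peft>=0.14.0
)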
@@ -244,13 +244,20 @@ class PeftAdapterMixin:
                     k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
                 }
 
-            # create LoraConfig
-            lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)
-
             # adapter_name
             if adapter_name is None:
                 adapter_name = get_adapter_name(self)
 
+            # create LoraConfig
+            lora_config = _create_lora_config(
+                state_dict,
+                network_alphas,
+                metadata,
+                rank,
+                model_state_dict=self.state_dict(),
+                adapter_name=adapter_name,
+            )
+
             # <Unsafe code
             # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
             # Now we remove any existing hooks to `_pipeline`.
@@ -150,7 +150,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
                 module.set_scale(adapter_name, 1.0)
 
 
-def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
+def get_peft_kwargs(
+    rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
+):
     rank_pattern = {}
     alpha_pattern = {}
     r = lora_alpha = list(rank_dict.values())[0]
@@ -180,7 +182,6 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
     else:
         lora_alpha = set(network_alpha_dict.values()).pop()
-
     # layer names without the Diffusers specific
     target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
     use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
     # for now we know that the "bias" keys are only associated with `lora_B`.
@@ -195,6 +196,21 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
         "use_dora": use_dora,
         "lora_bias": lora_bias,
     }
 
+    # Example: try load FusionX LoRA into Wan VACE
+    exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
+    if exclude_modules:
+        if not is_peft_version(">=", "0.14.0"):
+            msg = """
+It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft`
+version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U
+peft`. For most cases, this can be completely ignored. But if it seems unexpected, please file an issue -
+https://github.com/huggingface/diffusers/issues/new
+            """
+            logger.debug(msg)
+        else:
+            lora_config_kwargs.update({"exclude_modules": exclude_modules})
+
     return lora_config_kwargs
 
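To make the guarded branch above concrete, here is a rough sketch (toy values, not from this diff) of what `get_peft_kwargs` would hand back on a recent enough peft: the derived exclusions become one extra entry in the kwargs that later construct `LoraConfig`; on older peft the entry is skipped and only a debug message is logged.

# Hypothetical return value of get_peft_kwargs(...) when peft>=0.14.0 is installed;
# all concrete values below are made up for illustration.
lora_config_kwargs = {
    "r": 16,
    "lora_alpha": 16,
    "rank_pattern": {},
    "alpha_pattern": {},
    "target_modules": ["blocks.0.attn.to_q"],
    "use_dora": False,
    "lora_bias": False,
    "exclude_modules": ["vace_blocks.0.proj"],  # only added when peft supports it
}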
@@ -294,11 +310,7 @@ def check_peft_version(min_version: str) -> None:
 
 
 def _create_lora_config(
-    state_dict,
-    network_alphas,
-    metadata,
-    rank_pattern_dict,
-    is_unet: bool = True,
+    state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None
 ):
     from peft import LoraConfig
 
@@ -306,7 +318,12 @@ def _create_lora_config(
         lora_config_kwargs = metadata
     else:
         lora_config_kwargs = get_peft_kwargs(
-            rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet
+            rank_pattern_dict,
+            network_alpha_dict=network_alphas,
+            peft_state_dict=state_dict,
+            is_unet=is_unet,
+            model_state_dict=model_state_dict,
+            adapter_name=adapter_name,
         )
 
     _maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)
@@ -371,3 +388,27 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
 
     if warn_msg:
         logger.warning(warn_msg)
+
+
+def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
+    """
+    Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing the
+    `model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it
+    doesn't exist in `peft_state_dict`.
+    """
+    if model_state_dict is None:
+        return
+    all_modules = set()
+    string_to_replace = f"{adapter_name}." if adapter_name else ""
+
+    for name in model_state_dict.keys():
+        if string_to_replace:
+            name = name.replace(string_to_replace, "")
+        if "." in name:
+            module_name = name.rsplit(".", 1)[0]
+            all_modules.add(module_name)
+
+    target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
+    exclude_modules = list(all_modules - target_modules_set)
+
+    return exclude_modules
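A toy walk-through of the new helper (hypothetical state-dict keys, values irrelevant): every module that appears in the model's state dict but has no corresponding `.lora` entry in the checkpoint ends up in the exclusion list.

from diffusers.utils.peft_utils import _derive_exclude_modules

# Hypothetical inputs; only the key layout matters.
model_state_dict = {
    "blocks.0.attn.to_q.weight": None,
    "vace_blocks.0.proj.weight": None,
}
peft_state_dict = {
    "blocks.0.attn.to_q.lora_A.weight": None,
    "blocks.0.attn.to_q.lora_B.weight": None,
}

# all_modules        -> {"blocks.0.attn.to_q", "vace_blocks.0.proj"}
# target_modules_set -> {"blocks.0.attn.to_q"}
print(_derive_exclude_modules(model_state_dict, peft_state_dict))  # ["vace_blocks.0.proj"]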
@@ -24,7 +24,11 @@ from diffusers import (
     WanPipeline,
     WanTransformer3DModel,
 )
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps
+from diffusers.utils.testing_utils import (
+    floats_tensor,
+    require_peft_backend,
+    skip_mps,
+)
 
 
 sys.path.append(".")
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import inspect
 import os
 import re
@@ -291,6 +292,20 @@ class PeftLoraLoaderMixinTests:
 
         return modules_to_save
 
+    def _get_exclude_modules(self, pipe):
+        from diffusers.utils.peft_utils import _derive_exclude_modules
+
+        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+        denoiser = "unet" if self.unet_kwargs is not None else "transformer"
+        modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
+        denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
+        pipe.unload_lora_weights()
+        denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
+        exclude_modules = _derive_exclude_modules(
+            denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
+        )
+        return exclude_modules
+
     def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
         if text_lora_config is not None:
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -2326,6 +2341,58 @@ class PeftLoraLoaderMixinTests:
         )
         _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
+    @require_peft_version_greater("0.13.2")
+    def test_lora_exclude_modules(self):
+        """
+        Test to check if `exclude_modules` works or not. It works in the following way:
+        we first create a pipeline and insert LoRA config into it. We then derive a `set`
+        of modules to exclude by investigating its denoiser state dict and denoiser LoRA
+        state dict.
+
+        We then create a new LoRA config to include the `exclude_modules` and perform tests.
+        """
+        scheduler_cls = self.scheduler_classes[0]
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components).to(torch_device)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertTrue(output_no_lora.shape == self.output_shape)
+
+        # only supported for `denoiser` now
+        pipe_cp = copy.deepcopy(pipe)
+        pipe_cp, _ = self.add_adapters_to_pipeline(
+            pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
+        pipe_cp.to("cpu")
+        del pipe_cp
+
+        denoiser_lora_config.exclude_modules = denoiser_exclude_modules
+        pipe, _ = self.add_adapters_to_pipeline(
+            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
+            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
+            pipe.unload_lora_weights()
+            pipe.load_lora_weights(tmpdir)
+
+            output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertTrue(
+            not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
+            "LoRA should change outputs.",
+        )
+        self.assertTrue(
+            np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
+            "Lora outputs should match.",
+        )
+
     def test_inference_load_delete_load_adapters(self):
         "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
         for scheduler_cls in self.scheduler_classes: