Merge branch 'chroma-fork' into chroma-final
@@ -283,6 +283,8 @@
     title: AllegroTransformer3DModel
   - local: api/models/aura_flow_transformer2d
     title: AuraFlowTransformer2DModel
+  - local: api/models/chroma_transformer
+    title: ChromaTransformer2DModel
   - local: api/models/cogvideox_transformer3d
     title: CogVideoXTransformer3DModel
   - local: api/models/cogview3plus_transformer2d
@@ -405,6 +407,8 @@
     title: AutoPipeline
   - local: api/pipelines/blip_diffusion
     title: BLIP-Diffusion
+  - local: api/pipelines/chroma
+    title: Chroma
   - local: api/pipelines/cogvideox
     title: CogVideoX
   - local: api/pipelines/cogview3
@@ -282,10 +282,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
         revision = kwargs.pop("revision", None)
         subfolder = kwargs.pop("subfolder", None)

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
         model_file = _get_model_file(
             pretrained_model_name_or_path_or_dict,
             weights_name=weight_name,
@@ -159,10 +159,7 @@ class IPAdapterMixin:
                 " `low_cpu_mem_usage=False`."
             )

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
         state_dicts = []
         for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
             pretrained_model_name_or_path_or_dict, weight_name, subfolder
@@ -465,10 +462,7 @@ class FluxIPAdapterMixin:
                 " `low_cpu_mem_usage=False`."
             )

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
         state_dicts = []
         for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
             pretrained_model_name_or_path_or_dict, weight_name, subfolder
@@ -750,10 +744,7 @@ class SD3IPAdapterMixin:
                 " `low_cpu_mem_usage=False`."
             )

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
             model_file = _get_model_file(
@@ -14,6 +14,7 @@

 import copy
 import inspect
+import json
 import os
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union
@@ -45,6 +46,7 @@ from ..utils import (
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from ..utils.state_dict_utils import _load_sft_state_dict_metadata


 if is_transformers_available():
@@ -62,6 +64,7 @@ logger = logging.get_logger(__name__)

 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"


 def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
@@ -206,6 +209,7 @@ def _fetch_state_dict(
     subfolder,
     user_agent,
     allow_pickle,
+    metadata=None,
 ):
     model_file = None
     if not isinstance(pretrained_model_name_or_path_or_dict, dict):
@@ -236,11 +240,14 @@ def _fetch_state_dict(
                 user_agent=user_agent,
             )
             state_dict = safetensors.torch.load_file(model_file, device="cpu")
+            metadata = _load_sft_state_dict_metadata(model_file)
+
         except (IOError, safetensors.SafetensorError) as e:
             if not allow_pickle:
                 raise e
             # try loading non-safetensors weights
             model_file = None
+            metadata = None
             pass

         if model_file is None:
@@ -261,10 +268,11 @@ def _fetch_state_dict(
                 user_agent=user_agent,
             )
             state_dict = load_state_dict(model_file)
+            metadata = None
     else:
         state_dict = pretrained_model_name_or_path_or_dict

-    return state_dict
+    return state_dict, metadata


 def _best_guess_weight_name(
@@ -306,6 +314,11 @@ def _best_guess_weight_name(
     return weight_name


+def _pack_dict_with_prefix(state_dict, prefix):
+    sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()}
+    return sd_with_prefix
+
+
 def _load_lora_into_text_encoder(
     state_dict,
     network_alphas,
@@ -317,10 +330,14 @@ def _load_lora_into_text_encoder(
     _pipeline=None,
     low_cpu_mem_usage=False,
     hotswap: bool = False,
+    metadata=None,
 ):
     if not USE_PEFT_BACKEND:
         raise ValueError("PEFT backend is required for this method.")

+    if network_alphas and metadata:
+        raise ValueError("`network_alphas` and `metadata` cannot be specified both at the same time.")
+
     peft_kwargs = {}
     if low_cpu_mem_usage:
         if not is_peft_version(">=", "0.13.1"):
@@ -349,6 +366,8 @@ def _load_lora_into_text_encoder(
     # Load the layers corresponding to text encoder and make necessary adjustments.
     if prefix is not None:
         state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+        if metadata is not None:
+            metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}

     if len(state_dict) > 0:
         logger.info(f"Loading {prefix}.")
@@ -376,7 +395,10 @@ def _load_lora_into_text_encoder(
             alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
             network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}

-        lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
+        if metadata is not None:
+            lora_config_kwargs = metadata
+        else:
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)

         if "use_dora" in lora_config_kwargs:
             if lora_config_kwargs["use_dora"]:
@@ -398,7 +420,10 @@ def _load_lora_into_text_encoder(
             if is_peft_version("<=", "0.13.2"):
                 lora_config_kwargs.pop("lora_bias")

-        lora_config = LoraConfig(**lora_config_kwargs)
+        try:
+            lora_config = LoraConfig(**lora_config_kwargs)
+        except TypeError as e:
+            raise TypeError("`LoraConfig` class could not be instantiated.") from e

         # adapter_name
         if adapter_name is None:
@@ -889,8 +914,7 @@ class LoraBaseMixin:
     @staticmethod
     def pack_weights(layers, prefix):
         layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
-        layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
-        return layers_state_dict
+        return _pack_dict_with_prefix(layers_weights, prefix)

     @staticmethod
     def write_lora_layers(
@@ -900,16 +924,32 @@ class LoraBaseMixin:
         weight_name: str,
         save_function: Callable,
         safe_serialization: bool,
+        lora_adapter_metadata: Optional[dict] = None,
     ):
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return

+        if lora_adapter_metadata and not safe_serialization:
+            raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
+        if lora_adapter_metadata and not isinstance(lora_adapter_metadata, dict):
+            raise TypeError("`lora_adapter_metadata` must be of type `dict`.")
+
         if save_function is None:
             if safe_serialization:

                 def save_function(weights, filename):
-                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+                    # Inject framework format.
+                    metadata = {"format": "pt"}
+                    if lora_adapter_metadata:
+                        for key, value in lora_adapter_metadata.items():
+                            if isinstance(value, set):
+                                lora_adapter_metadata[key] = list(value)
+                        metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(
+                            lora_adapter_metadata, indent=2, sort_keys=True
+                        )
+
+                    return safetensors.torch.save_file(weights, filename, metadata=metadata)

             else:
                 save_function = torch.save
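For illustration, a minimal sketch of what the new inner `save_function` above does with `lora_adapter_metadata` (dummy tensor and config values, not taken from this commit). safetensors headers only accept string-to-string metadata, which is why sets are converted to lists and the whole dict is JSON-encoded under `LORA_ADAPTER_METADATA_KEY`:

import json

import safetensors.torch
import torch

LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"

# Dummy LoRA weights and adapter config, only to show the resulting header layout.
weights = {"layer.lora_A.weight": torch.zeros(4, 4)}
lora_adapter_metadata = {"r": 4, "lora_alpha": 8, "target_modules": {"to_q", "to_k"}}

metadata = {"format": "pt"}
for key, value in lora_adapter_metadata.items():
    if isinstance(value, set):
        lora_adapter_metadata[key] = list(value)  # sets are not JSON-serializable
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)

# Every value in `metadata` is now a plain string, which is all the safetensors header allows.
safetensors.torch.save_file(weights, "pytorch_lora_weights.safetensors", metadata=metadata)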
@@ -1605,9 +1605,18 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     if diff_keys:
         for diff_k in diff_keys:
             param = original_state_dict[diff_k]
+            # The magnitudes of the .diff-ending weights are very low (most are below 1e-4, some are upto 1e-3,
+            # and 2 of them are about 1.6e-2 [the case with AccVideo lora]). The low magnitudes mostly correspond
+            # to norm layers. Ignoring them is the best option at the moment until a better solution is found. It
+            # is okay to ignore because they do not affect the model output in a significant manner.
+            threshold = 1.6e-2
+            absdiff = param.abs().max() - param.abs().min()
             all_zero = torch.all(param == 0).item()
-            if all_zero:
-                logger.debug(f"Removed {diff_k} key from the state dict as it's all zeros.")
+            all_absdiff_lower_than_threshold = absdiff < threshold
+            if all_zero or all_absdiff_lower_than_threshold:
+                logger.debug(
+                    f"Removed {diff_k} key from the state dict as it's all zeros, or values lower than hardcoded threshold."
+                )
                 original_state_dict.pop(diff_k)

     # For the `diff_b` keys, we treat them as lora_bias.
@@ -1655,12 +1664,16 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):

         # FFN
         for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
-            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.{o}.{lora_down_key}.weight"
-            )
-            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.{o}.{lora_up_key}.weight"
-            )
+            original_key = f"blocks.{i}.{o}.{lora_down_key}.weight"
+            converted_key = f"blocks.{i}.ffn.{c}.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"blocks.{i}.{o}.{lora_up_key}.weight"
+            converted_key = f"blocks.{i}.ffn.{c}.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
             if f"blocks.{i}.{o}.diff_b" in original_state_dict:
                 converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.bias"] = original_state_dict.pop(
                     f"blocks.{i}.{o}.diff_b"
@@ -1669,12 +1682,16 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     # Remaining.
     if original_state_dict:
         if any("time_projection" in k for k in original_state_dict):
-            converted_state_dict["condition_embedder.time_proj.lora_A.weight"] = original_state_dict.pop(
-                f"time_projection.1.{lora_down_key}.weight"
-            )
-            converted_state_dict["condition_embedder.time_proj.lora_B.weight"] = original_state_dict.pop(
-                f"time_projection.1.{lora_up_key}.weight"
-            )
+            original_key = f"time_projection.1.{lora_down_key}.weight"
+            converted_key = "condition_embedder.time_proj.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"time_projection.1.{lora_up_key}.weight"
+            converted_key = "condition_embedder.time_proj.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
         if "time_projection.1.diff_b" in original_state_dict:
             converted_state_dict["condition_embedder.time_proj.lora_B.bias"] = original_state_dict.pop(
                 "time_projection.1.diff_b"
@@ -1709,6 +1726,20 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
                 original_state_dict.pop(f"{text_time}.{b_n}.diff_b")
             )

+        for img_ours, img_theirs in [
+            ("ff.net.0.proj", "img_emb.proj.1"),
+            ("ff.net.2", "img_emb.proj.3"),
+        ]:
+            original_key = f"{img_theirs}.{lora_down_key}.weight"
+            converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"{img_theirs}.{lora_up_key}.weight"
+            converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
     if len(original_state_dict) > 0:
         diff = all(".diff" in k for k in original_state_dict)
         if diff:
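For illustration, the refactor above repeatedly applies the same guarded-rename pattern so that optional source keys no longer raise `KeyError`; a tiny standalone sketch with made-up key names (not a real Wan checkpoint):

# Only move a weight when the source key actually exists.
original_state_dict = {"blocks.0.ffn.0.lora_down.weight": 1}  # stand-in value, not a real tensor
converted_state_dict = {}

renames = {
    "blocks.0.ffn.0.lora_down.weight": "blocks.0.ffn.net.0.proj.lora_A.weight",
    "img_emb.proj.1.lora_down.weight": "condition_embedder.image_embedder.ff.net.0.proj.lora_A.weight",  # absent here
}
for original_key, converted_key in renames.items():
    if original_key in original_state_dict:
        converted_state_dict[converted_key] = original_state_dict.pop(original_key)

assert "blocks.0.ffn.net.0.proj.lora_A.weight" in converted_state_dict
assert not original_state_dict  # everything that was present has been consumed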
File diff suppressed because it is too large
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+import json
 import os
 from functools import partial
 from pathlib import Path
@@ -185,6 +186,7 @@ class PeftAdapterMixin:
                 Note that hotswapping adapters of the text encoder is not yet supported. There are some further
                 limitations to this technique, which are documented here:
                 https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+            metadata: TODO
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer
@@ -202,6 +204,7 @@ class PeftAdapterMixin:
         network_alphas = kwargs.pop("network_alphas", None)
         _pipeline = kwargs.pop("_pipeline", None)
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
+        metadata = kwargs.pop("metadata", None)
         allow_pickle = False

         if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
@@ -209,12 +212,9 @@ class PeftAdapterMixin:
                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
             )

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -227,12 +227,17 @@ class PeftAdapterMixin:
             subfolder=subfolder,
             user_agent=user_agent,
             allow_pickle=allow_pickle,
+            metadata=metadata,
         )
         if network_alphas is not None and prefix is None:
             raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
+        if network_alphas and metadata:
+            raise ValueError("Both `network_alphas` and `metadata` cannot be specified.")

         if prefix is not None:
             state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+            if metadata is not None:
+                metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}

         if len(state_dict) > 0:
             if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
@@ -267,7 +272,12 @@ class PeftAdapterMixin:
                     k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
                 }

-            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
+            if metadata is not None:
+                lora_config_kwargs = metadata
+            else:
+                lora_config_kwargs = get_peft_kwargs(
+                    rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict
+                )
             _maybe_raise_error_for_ambiguity(lora_config_kwargs)

             if "use_dora" in lora_config_kwargs:
@@ -290,7 +300,11 @@ class PeftAdapterMixin:
                 if is_peft_version("<=", "0.13.2"):
                     lora_config_kwargs.pop("lora_bias")

-            lora_config = LoraConfig(**lora_config_kwargs)
+            try:
+                lora_config = LoraConfig(**lora_config_kwargs)
+            except TypeError as e:
+                raise TypeError("`LoraConfig` class could not be instantiated.") from e
+
             # adapter_name
             if adapter_name is None:
                 adapter_name = get_adapter_name(self)
@@ -445,17 +459,13 @@ class PeftAdapterMixin:
                 underlying model has multiple adapters loaded.
             upcast_before_saving (`bool`, defaults to `False`):
                 Whether to cast the underlying model to `torch.float32` before serialization.
-            save_function (`Callable`):
-                The function to use to save the state dictionary. Useful during distributed training when you need to
-                replace `torch.save` with another method. Can be configured with the environment variable
-                `DIFFUSERS_SAVE_MODE`.
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
         """
         from peft.utils import get_peft_model_state_dict

-        from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
+        from .lora_base import LORA_ADAPTER_METADATA_KEY, LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE

         if adapter_name is None:
             adapter_name = get_adapter_name(self)
@@ -463,6 +473,8 @@ class PeftAdapterMixin:
         if adapter_name not in getattr(self, "peft_config", {}):
             raise ValueError(f"Adapter name {adapter_name} not found in the model.")

+        lora_adapter_metadata = self.peft_config[adapter_name].to_dict()
+
         lora_layers_to_save = get_peft_model_state_dict(
             self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
         )
@@ -472,7 +484,15 @@ class PeftAdapterMixin:
         if safe_serialization:

             def save_function(weights, filename):
-                return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+                # Inject framework format.
+                metadata = {"format": "pt"}
+                if lora_adapter_metadata is not None:
+                    for key, value in lora_adapter_metadata.items():
+                        if isinstance(value, set):
+                            lora_adapter_metadata[key] = list(value)
+                    metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
+
+                return safetensors.torch.save_file(weights, filename, metadata=metadata)

         else:
             save_function = torch.save
@@ -485,7 +505,6 @@ class PeftAdapterMixin:
         else:
             weight_name = LORA_WEIGHT_NAME

-        # TODO: we could consider saving the `peft_config` as well.
         save_path = Path(save_directory, weight_name).as_posix()
         save_function(lora_layers_to_save, save_path)
         logger.info(f"Model weights saved in {save_path}")
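For illustration, the adapter-metadata round trip that `save_lora_adapter`/`load_lora_adapter` now perform end to end; a sketch assuming a small test-sized `UNet2DConditionModel` (any `PeftAdapterMixin` subclass works), an installed `peft`, and a diffusers build that includes this change:

import tempfile

from peft import LoraConfig

from diffusers import UNet2DConditionModel

model = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=1,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
)
model.add_adapter(LoraConfig(r=4, lora_alpha=8, target_modules=["to_q", "to_k", "to_v", "to_out.0"]))

with tempfile.TemporaryDirectory() as tmpdir:
    model.save_lora_adapter(tmpdir)  # writes pytorch_lora_weights.safetensors with the metadata header
    model.unload_lora()
    model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)

# The reloaded adapter reconstructs its LoraConfig from the stored metadata.
assert model.peft_config["default_0"].to_dict()["lora_alpha"] == 8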
@@ -155,10 +155,7 @@ class UNet2DConditionLoadersMixin:
         use_safetensors = True
         allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

         model_file = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
@@ -16,6 +16,7 @@ State dict utilities: utility methods for converting state dicts easily
 """

 import enum
+import json

 from .import_utils import is_torch_available
 from .logging import get_logger
@@ -347,3 +348,16 @@ def state_dict_all_zero(state_dict, filter_str=None):
         state_dict = {k: v for k, v in state_dict.items() if any(f in k for f in filter_str)}

     return all(torch.all(param == 0).item() for param in state_dict.values())
+
+
+def _load_sft_state_dict_metadata(model_file: str):
+    import safetensors.torch
+
+    from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+
+    with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
+        metadata = f.metadata() or {}
+
+    metadata.pop("format", None)
+    raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+    return json.loads(raw) if raw else None
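For reference, a self-contained round trip through the logic of `_load_sft_state_dict_metadata` above (temporary file and dummy tensor; the read side repeats the same `safe_open` calls rather than importing the private helper):

import json
import tempfile

import safetensors.torch
import torch

LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"

with tempfile.TemporaryDirectory() as tmpdir:
    path = f"{tmpdir}/pytorch_lora_weights.safetensors"
    header = {"format": "pt", LORA_ADAPTER_METADATA_KEY: json.dumps({"r": 4, "lora_alpha": 8})}
    safetensors.torch.save_file({"dummy.lora_A.weight": torch.zeros(2, 2)}, path, metadata=header)

    # Read side: mirrors the helper added in the hunk above.
    with safetensors.torch.safe_open(path, framework="pt", device="cpu") as f:
        metadata = f.metadata() or {}
    metadata.pop("format", None)
    raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
    assert json.loads(raw) == {"r": 4, "lora_alpha": 8}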
@@ -133,6 +133,29 @@ def numpy_cosine_similarity_distance(a, b):
     return distance


+def check_if_dicts_are_equal(dict1, dict2):
+    dict1, dict2 = dict1.copy(), dict2.copy()
+
+    for key, value in dict1.items():
+        if isinstance(value, set):
+            dict1[key] = sorted(value)
+    for key, value in dict2.items():
+        if isinstance(value, set):
+            dict2[key] = sorted(value)
+
+    for key in dict1:
+        if key not in dict2:
+            return False
+        if dict1[key] != dict2[key]:
+            return False
+
+    for key in dict2:
+        if key not in dict1:
+            return False
+
+    return True
+
+
 def print_tensor_test(
     tensor,
     limit_to_slices=None,
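A short usage note on `check_if_dicts_are_equal` above: set values are normalized with `sorted()` before comparison, so the ordering of, for example, `target_modules` cannot cause false mismatches (the values below are made up):

from diffusers.utils.testing_utils import check_if_dicts_are_equal  # available once this change lands

config_a = {"r": 4, "target_modules": {"to_q", "to_k"}}
config_b = {"r": 4, "target_modules": {"to_k", "to_q"}}
assert check_if_dicts_are_equal(config_a, config_b)
assert not check_if_dicts_are_equal(config_a, {"r": 8, "target_modules": {"to_q", "to_k"}})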
@@ -24,11 +24,7 @@ from diffusers import (
     WanPipeline,
     WanTransformer3DModel,
 )
-from diffusers.utils.testing_utils import (
-    floats_tensor,
-    require_peft_backend,
-    skip_mps,
-)
+from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps


 sys.path.append(".")
@@ -22,6 +22,7 @@ from itertools import product
 import numpy as np
 import pytest
 import torch
+from parameterized import parameterized

 from diffusers import (
     AutoencoderKL,
@@ -33,6 +34,7 @@ from diffusers.utils import logging
 from diffusers.utils.import_utils import is_peft_available
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    check_if_dicts_are_equal,
     floats_tensor,
     is_torch_version,
     require_peft_backend,
@@ -71,6 +73,13 @@ def check_if_lora_correctly_set(model) -> bool:
     return False


+def check_module_lora_metadata(parsed_metadata: dict, lora_metadatas: dict, module_key: str):
+    extracted = {
+        k.removeprefix(f"{module_key}."): v for k, v in parsed_metadata.items() if k.startswith(f"{module_key}.")
+    }
+    check_if_dicts_are_equal(extracted, lora_metadatas[f"{module_key}_lora_adapter_metadata"])
+
+
 def initialize_dummy_state_dict(state_dict):
     if not all(v.device.type == "meta" for _, v in state_dict.items()):
         raise ValueError("`state_dict` has non-meta values.")
@@ -118,7 +127,7 @@ class PeftLoraLoaderMixinTests:
     text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
     denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]

-    def get_dummy_components(self, scheduler_cls=None, use_dora=False):
+    def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
         if self.unet_kwargs and self.transformer_kwargs:
             raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
         if self.has_two_text_encoders and self.has_three_text_encoders:
@@ -126,6 +135,7 @@ class PeftLoraLoaderMixinTests:

         scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
         rank = 4
+        lora_alpha = rank if lora_alpha is None else lora_alpha

         torch.manual_seed(0)
         if self.unet_kwargs is not None:
@@ -161,7 +171,7 @@ class PeftLoraLoaderMixinTests:

         text_lora_config = LoraConfig(
             r=rank,
-            lora_alpha=rank,
+            lora_alpha=lora_alpha,
             target_modules=self.text_encoder_target_modules,
             init_lora_weights=False,
             use_dora=use_dora,
@@ -169,7 +179,7 @@ class PeftLoraLoaderMixinTests:

         denoiser_lora_config = LoraConfig(
             r=rank,
-            lora_alpha=rank,
+            lora_alpha=lora_alpha,
             target_modules=self.denoiser_target_modules,
             init_lora_weights=False,
             use_dora=use_dora,
@@ -246,6 +256,13 @@ class PeftLoraLoaderMixinTests:
                 state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module)
         return state_dicts

+    def _get_lora_adapter_metadata(self, modules_to_save):
+        metadatas = {}
+        for module_name, module in modules_to_save.items():
+            if module is not None:
+                metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
+        return metadatas
+
     def _get_modules_to_save(self, pipe, has_denoiser=False):
         modules_to_save = {}
         lora_loadable_modules = self.pipeline_class._lora_loadable_modules
@@ -2214,6 +2231,86 @@ class PeftLoraLoaderMixinTests:
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
         pipe(**inputs, generator=torch.manual_seed(0))[0]

+    @parameterized.expand([4, 8, 16])
+    def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha):
+        scheduler_cls = self.scheduler_classes[0]
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
+            scheduler_cls, lora_alpha=lora_alpha
+        )
+        pipe = self.pipeline_class(**components)
+
+        pipe, _ = self.check_if_adapters_added_correctly(
+            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
+            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
+            pipe.unload_lora_weights()
+
+            out = pipe.lora_state_dict(tmpdir, return_lora_metadata=True)
+            if len(out) == 3:
+                _, _, parsed_metadata = out
+            elif len(out) == 2:
+                _, parsed_metadata = out
+
+            denoiser_key = (
+                f"{self.pipeline_class.transformer_name}"
+                if self.transformer_kwargs is not None
+                else f"{self.pipeline_class.unet_name}"
+            )
+            self.assertTrue(any(k.startswith(f"{denoiser_key}.") for k in parsed_metadata))
+            check_module_lora_metadata(
+                parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key
+            )
+
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                text_encoder_key = self.pipeline_class.text_encoder_name
+                self.assertTrue(any(k.startswith(f"{text_encoder_key}.") for k in parsed_metadata))
+                check_module_lora_metadata(
+                    parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key
+                )
+
+            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                text_encoder_2_key = "text_encoder_2"
+                self.assertTrue(any(k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata))
+                check_module_lora_metadata(
+                    parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key
+                )
+
+    @parameterized.expand([4, 8, 16])
+    def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
+        scheduler_cls = self.scheduler_classes[0]
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
+            scheduler_cls, lora_alpha=lora_alpha
+        )
+        pipe = self.pipeline_class(**components).to(torch_device)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertTrue(output_no_lora.shape == self.output_shape)
+
+        pipe, _ = self.check_if_adapters_added_correctly(
+            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
+            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
+            pipe.unload_lora_weights()
+            pipe.load_lora_weights(tmpdir)
+
+            output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+            self.assertTrue(
+                np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
+            )
+
     def test_inference_load_delete_load_adapters(self):
         "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
         for scheduler_cls in self.scheduler_classes:
@@ -30,6 +30,7 @@ from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import requests_mock
+import safetensors.torch
 import torch
 import torch.nn as nn
 from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
@@ -62,6 +63,7 @@ from diffusers.utils.testing_utils import (
     backend_max_memory_allocated,
     backend_reset_peak_memory_stats,
     backend_synchronize,
+    check_if_dicts_are_equal,
     get_python_version,
     is_torch_compile,
     numpy_cosine_similarity_distance,
@@ -1057,11 +1059,10 @@ class ModelTesterMixin:
                 " from `_deprecated_kwargs = [<deprecated_argument>]`"
             )

-    @parameterized.expand([True, False])
+    @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
     @torch.no_grad()
     @unittest.skipIf(not is_peft_available(), "Only with PEFT")
-    def test_lora_save_load_adapter(self, use_dora=False):
-        import safetensors
-
+    def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False):
         from peft import LoraConfig
         from peft.utils import get_peft_model_state_dict

@@ -1077,8 +1078,8 @@ class ModelTesterMixin:
         output_no_lora = model(**inputs_dict, return_dict=False)[0]

         denoiser_lora_config = LoraConfig(
-            r=4,
-            lora_alpha=4,
+            r=rank,
+            lora_alpha=lora_alpha,
             target_modules=["to_q", "to_k", "to_v", "to_out.0"],
             init_lora_weights=False,
             use_dora=use_dora,
@@ -1145,6 +1146,90 @@ class ModelTesterMixin:

         self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))

+    @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
+    @torch.no_grad()
+    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+    def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora):
+        from peft import LoraConfig
+
+        from diffusers.loaders.peft import PeftAdapterMixin
+
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            return
+
+        denoiser_lora_config = LoraConfig(
+            r=rank,
+            lora_alpha=lora_alpha,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=use_dora,
+        )
+        model.add_adapter(denoiser_lora_config)
+        metadata = model.peft_config["default"].to_dict()
+        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.save_lora_adapter(tmpdir)
+            model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            self.assertTrue(os.path.isfile(model_file))
+
+            model.unload_lora()
+            self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+            model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
+            parsed_metadata = model.peft_config["default_0"].to_dict()
+            check_if_dicts_are_equal(metadata, parsed_metadata)
+
+    @torch.no_grad()
+    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+    def test_lora_adapter_wrong_metadata_raises_error(self):
+        from peft import LoraConfig
+
+        from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+        from diffusers.loaders.peft import PeftAdapterMixin
+
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            return
+
+        denoiser_lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=False,
+        )
+        model.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.save_lora_adapter(tmpdir)
+            model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            self.assertTrue(os.path.isfile(model_file))
+
+            # Perturb the metadata in the state dict.
+            loaded_state_dict = safetensors.torch.load_file(model_file)
+            metadata = {"format": "pt"}
+            lora_adapter_metadata = denoiser_lora_config.to_dict()
+            lora_adapter_metadata.update({"foo": 1, "bar": 2})
+            for key, value in lora_adapter_metadata.items():
+                if isinstance(value, set):
+                    lora_adapter_metadata[key] = list(value)
+            metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
+            safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
+
+            model.unload_lora()
+            self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+            with self.assertRaises(TypeError) as err_context:
+                model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
+            self.assertTrue("`LoraConfig` class could not be instantiated" in str(err_context.exception))
+
     @require_torch_accelerator
     def test_cpu_offload(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -125,7 +125,7 @@ class ChromaTransformerTests(ModelTesterMixin, unittest.TestCase):
             "num_layers": 1,
             "num_single_layers": 1,
             "attention_head_dim": 16,
-            "num_attention_heads": 2,
+            "num_attention_heads": 192,
             "joint_attention_dim": 32,
             "axes_dims_rope": [4, 4, 8],
             "approximator_num_channels": 8,
@@ -4,20 +4,10 @@ import numpy as np
 import torch
 from transformers import AutoTokenizer, T5EncoderModel

-from diffusers import (
-    AutoencoderKL,
-    ChromaPipeline,
-    ChromaTransformer2DModel,
-    FasterCacheConfig,
-    FlowMatchEulerDiscreteScheduler,
-)
-from diffusers.utils.testing_utils import (
-    torch_device,
-)
+from diffusers import AutoencoderKL, ChromaPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
+from diffusers.utils.testing_utils import torch_device

 from ..test_pipelines_common import (
     FasterCacheTesterMixin,
     FluxIPAdapterTesterMixin,
     PipelineTesterMixin,
     check_qkv_fusion_matches_attn_procs_length,
     check_qkv_fusion_processors_exist,
@@ -28,10 +18,12 @@ class ChromaPipelineFastTests(
     unittest.TestCase,
     PipelineTesterMixin,
     FluxIPAdapterTesterMixin,
     FasterCacheTesterMixin,
 ):
-    pipeline_class = ChromaPipeline
-    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
+    pipeline_class = ChromaPipeline
+    params = frozenset(["prompt", "negative_prompt", "height", "width", "guidance_scale", "prompt_embeds"])
+    batch_params = frozenset(["prompt"])

     # there is no xformers processor for Flux
@@ -39,14 +31,6 @@ class ChromaPipelineFastTests(
     test_layerwise_casting = True
     test_group_offloading = True

-    faster_cache_config = FasterCacheConfig(
-        spatial_attention_block_skip_range=2,
-        spatial_attention_timestep_skip_range=(-1, 901),
-        unconditional_batch_skip_range=2,
-        attention_weight_callback=lambda _: 0.5,
-        is_guidance_distilled=False,
-    )
-
     def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
         torch.manual_seed(0)
         transformer = ChromaTransformer2DModel(
@@ -55,7 +39,7 @@ class ChromaPipelineFastTests(
             num_layers=num_layers,
             num_single_layers=num_single_layers,
             attention_head_dim=16,
-            num_attention_heads=2,
+            num_attention_heads=192,
             joint_attention_dim=32,
             axes_dims_rope=[4, 4, 8],
         )