This commit is contained in:
sayakpaul
2025-12-11 11:17:28 +05:30
parent 14f51c51f0
commit 3e374fda38
15 changed files with 943 additions and 897 deletions

View File

@@ -35,8 +35,11 @@ from ..utils import (
deprecate,
get_adapter_name,
is_accelerate_available,
is_bitsandbytes_available,
is_gguf_available,
is_peft_available,
is_peft_version,
is_torch_version,
is_transformers_available,
is_transformers_version,
logging,
@@ -64,6 +67,20 @@ LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"
TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
TRANSFORMER_NAME = "transformer"
_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
if is_torch_version(">=", "1.9.0"):
if (
is_peft_available()
and is_peft_version(">=", "0.13.1")
and is_transformers_available()
and is_transformers_version(">", "4.45.2")
):
_LOW_CPU_MEM_USAGE_DEFAULT_LORA = True
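For context, a minimal sketch (not part of this commit) of how the gated default above is consumed; it assumes the constant is importable from `diffusers.loaders.lora_base` as the imports later in this diff suggest, and the commented pipeline call uses placeholder names.

```python
from diffusers.loaders.lora_base import _LOW_CPU_MEM_USAGE_DEFAULT_LORA

# The default is only True when recent enough torch/peft/transformers are installed;
# callers can still pass low_cpu_mem_usage explicitly to override it.
print(_LOW_CPU_MEM_USAGE_DEFAULT_LORA)
# pipeline.load_lora_weights("some-user/some-lora", low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT_LORA)
```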
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
"""
@@ -475,6 +492,55 @@ def _func_optionally_disable_offloading(_pipeline):
return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
def _maybe_dequantize_weight_for_expanded_lora(model, module):
if is_bitsandbytes_available():
from ..quantizers.bitsandbytes import dequantize_bnb_weight
if is_gguf_available():
from ..quantizers.gguf.utils import dequantize_gguf_tensor
is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
is_bnb_8bit_quantized = module.weight.__class__.__name__ == "Int8Params"
is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"
if is_bnb_4bit_quantized and not is_bitsandbytes_available():
raise ValueError(
"The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
)
if is_bnb_8bit_quantized and not is_bitsandbytes_available():
raise ValueError(
"The checkpoint seems to have been quantized with `bitsandbytes` (8bits). Install `bitsandbytes` to load quantized checkpoints."
)
if is_gguf_quantized and not is_gguf_available():
raise ValueError(
"The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
)
weight_on_cpu = False
if module.weight.device.type == "cpu":
weight_on_cpu = True
device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
if is_bnb_4bit_quantized or is_bnb_8bit_quantized:
module_weight = dequantize_bnb_weight(
module.weight.to(device) if weight_on_cpu else module.weight,
state=module.weight.quant_state if is_bnb_4bit_quantized else module.state,
dtype=model.dtype,
).data
elif is_gguf_quantized:
module_weight = dequantize_gguf_tensor(
module.weight.to(device) if weight_on_cpu else module.weight,
)
module_weight = module_weight.to(model.dtype)
else:
module_weight = module.weight.data
if weight_on_cpu:
module_weight = module_weight.cpu()
return module_weight
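As an illustration only (not part of the commit), a hedged sketch of how this helper might be used: it mirrors the `hf_quantizer` check applied later in `_maybe_expand_transformer_param_shape_or_error_` and assumes a diffusers model that may carry bitsandbytes or GGUF quantized weights.

```python
import torch

from diffusers.loaders.lora_base import _maybe_dequantize_weight_for_expanded_lora


def dense_linear_weights(model: torch.nn.Module) -> dict:
    """Collect a plain (dequantized) weight tensor for every nn.Linear in `model`."""
    is_quantized = hasattr(model, "hf_quantizer")
    weights = {}
    for name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        if is_quantized:
            # Dequantizes bnb 4-bit/8-bit or GGUF params; falls back to .data otherwise.
            weights[name] = _maybe_dequantize_weight_for_expanded_lora(model, module)
        else:
            weights[name] = module.weight.data
    return weights
```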
class LoraBaseMixin:
"""Utility class for handling LoRAs."""

View File

@@ -21,29 +21,24 @@ from huggingface_hub.utils import validate_hf_hub_args
from ..utils import (
USE_PEFT_BACKEND,
deprecate,
get_submodule_by_name,
is_bitsandbytes_available,
is_gguf_available,
is_peft_available,
is_peft_version,
is_torch_version,
is_transformers_available,
is_transformers_version,
logging,
)
from .lora_base import ( # noqa
_LOW_CPU_MEM_USAGE_DEFAULT_LORA,
LORA_WEIGHT_NAME,
LORA_WEIGHT_NAME_SAFE,
TEXT_ENCODER_NAME,
TRANSFORMER_NAME,
UNET_NAME,
LoraBaseMixin,
_fetch_state_dict,
_load_lora_into_text_encoder,
_maybe_dequantize_weight_for_expanded_lora,
_pack_dict_with_prefix,
)
from .lora_conversion_utils import (
_convert_bfl_flux_control_lora_to_diffusers,
_convert_fal_kontext_lora_to_diffusers,
_convert_hunyuan_video_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers,
_convert_musubi_wan_lora_to_diffusers,
_convert_non_diffusers_flux2_lora_to_diffusers,
_convert_non_diffusers_hidream_lora_to_diffusers,
@@ -53,79 +48,12 @@ from .lora_conversion_utils import (
_convert_non_diffusers_qwen_lora_to_diffusers,
_convert_non_diffusers_wan_lora_to_diffusers,
_convert_non_diffusers_z_image_lora_to_diffusers,
_convert_xlabs_flux_lora_to_diffusers,
_maybe_map_sgm_blocks_to_diffusers,
)
_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
if is_torch_version(">=", "1.9.0"):
if (
is_peft_available()
and is_peft_version(">=", "0.13.1")
and is_transformers_available()
and is_transformers_version(">", "4.45.2")
):
_LOW_CPU_MEM_USAGE_DEFAULT_LORA = True
logger = logging.get_logger(__name__)
TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
TRANSFORMER_NAME = "transformer"
_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
def _maybe_dequantize_weight_for_expanded_lora(model, module):
if is_bitsandbytes_available():
from ..quantizers.bitsandbytes import dequantize_bnb_weight
if is_gguf_available():
from ..quantizers.gguf.utils import dequantize_gguf_tensor
is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
is_bnb_8bit_quantized = module.weight.__class__.__name__ == "Int8Params"
is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"
if is_bnb_4bit_quantized and not is_bitsandbytes_available():
raise ValueError(
"The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
)
if is_bnb_8bit_quantized and not is_bitsandbytes_available():
raise ValueError(
"The checkpoint seems to have been quantized with `bitsandbytes` (8bits). Install `bitsandbytes` to load quantized checkpoints."
)
if is_gguf_quantized and not is_gguf_available():
raise ValueError(
"The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
)
weight_on_cpu = False
if module.weight.device.type == "cpu":
weight_on_cpu = True
device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
if is_bnb_4bit_quantized or is_bnb_8bit_quantized:
module_weight = dequantize_bnb_weight(
module.weight.to(device) if weight_on_cpu else module.weight,
state=module.weight.quant_state if is_bnb_4bit_quantized else module.state,
dtype=model.dtype,
).data
elif is_gguf_quantized:
module_weight = dequantize_gguf_tensor(
module.weight.to(device) if weight_on_cpu else module.weight,
)
module_weight = module_weight.to(model.dtype)
else:
module_weight = module.weight.data
if weight_on_cpu:
module_weight = module_weight.cpu()
return module_weight
class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
r"""
@@ -1483,802 +1411,15 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
class FluxLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`FluxTransformer2DModel`],
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
Specific to [`FluxPipeline`].
"""
_lora_loadable_modules = ["transformer", "text_encoder"]
transformer_name = TRANSFORMER_NAME
text_encoder_name = TEXT_ENCODER_NAME
_control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
return_alphas: bool = False,
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
def __new__(cls, *args, **kwargs):
deprecation_message = (
"Importing `FluxLoraLoaderMixin` class like `from diffusers.loaders import FluxLoraLoaderMixin` is deprecated and will be removed in a future version. "
"Please use `from diffusers.pipelines.flux.lora_utils import FluxLoraLoaderMixin` instead. "
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
deprecate("FluxLoraLoaderMixin", "1.0.0", deprecation_message, standard_warn=False)
from ..pipelines.flux.lora_utils import FluxLoraLoaderMixin
# TODO (sayakpaul): do a follow-up to clean and try to unify the conditions.
is_kohya = any(".lora_down.weight" in k for k in state_dict)
if is_kohya:
state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
# Kohya already takes care of scaling the LoRA parameters with alpha.
return cls._prepare_outputs(
state_dict,
metadata=metadata,
alphas=None,
return_alphas=return_alphas,
return_metadata=return_lora_metadata,
)
is_xlabs = any("processor" in k for k in state_dict)
if is_xlabs:
state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
# xlabs doesn't use `alpha`.
return cls._prepare_outputs(
state_dict,
metadata=metadata,
alphas=None,
return_alphas=return_alphas,
return_metadata=return_lora_metadata,
)
is_bfl_control = any("query_norm.scale" in k for k in state_dict)
if is_bfl_control:
state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
return cls._prepare_outputs(
state_dict,
metadata=metadata,
alphas=None,
return_alphas=return_alphas,
return_metadata=return_lora_metadata,
)
is_fal_kontext = any("base_model" in k for k in state_dict)
if is_fal_kontext:
state_dict = _convert_fal_kontext_lora_to_diffusers(state_dict)
return cls._prepare_outputs(
state_dict,
metadata=metadata,
alphas=None,
return_alphas=return_alphas,
return_metadata=return_lora_metadata,
)
# For state dicts like
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
keys = list(state_dict.keys())
network_alphas = {}
for k in keys:
if "alpha" in k:
alpha_value = state_dict.get(k)
if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
alpha_value, float
):
network_alphas[k] = state_dict.pop(k)
else:
raise ValueError(
f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue."
)
if return_alphas or return_lora_metadata:
return cls._prepare_outputs(
state_dict,
metadata=metadata,
alphas=network_alphas,
return_alphas=return_alphas,
return_metadata=return_lora_metadata,
)
else:
return state_dict
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
`self.text_encoder`.
All kwargs are forwarded to `self.lora_state_dict`.
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
loaded.
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
dict is loaded into `self.transformer`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
kwargs["return_lora_metadata"] = True
state_dict, network_alphas, metadata = self.lora_state_dict(
pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
)
has_lora_keys = any("lora" in key for key in state_dict.keys())
# Flux Control LoRAs also have norm keys
has_norm_keys = any(
norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys
)
if not (has_lora_keys or has_norm_keys):
raise ValueError("Invalid LoRA checkpoint.")
transformer_lora_state_dict = {
k: state_dict.get(k)
for k in list(state_dict.keys())
if k.startswith(f"{self.transformer_name}.") and "lora" in k
}
transformer_norm_state_dict = {
k: state_dict.pop(k)
for k in list(state_dict.keys())
if k.startswith(f"{self.transformer_name}.")
and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
}
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
has_param_with_expanded_shape = False
if len(transformer_lora_state_dict) > 0:
has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
transformer, transformer_lora_state_dict, transformer_norm_state_dict
)
if has_param_with_expanded_shape:
logger.info(
"The LoRA weights contain parameters that have different shapes that expected by the transformer. "
"As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
"To get a comprehensive list of parameter names that were modified, enable debug logging."
)
if len(transformer_lora_state_dict) > 0:
transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
transformer=transformer, lora_state_dict=transformer_lora_state_dict
)
for k in transformer_lora_state_dict:
state_dict.update({k: transformer_lora_state_dict[k]})
self.load_lora_into_transformer(
state_dict,
network_alphas=network_alphas,
transformer=transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
if len(transformer_norm_state_dict) > 0:
transformer._transformer_norm_layers = self._load_norm_into_transformer(
transformer_norm_state_dict,
transformer=transformer,
discard_original_layers=False,
)
self.load_lora_into_text_encoder(
state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder,
prefix=self.text_encoder_name,
lora_scale=self.lora_scale,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
def load_lora_into_transformer(
cls,
state_dict,
network_alphas,
transformer,
adapter_name=None,
metadata=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=network_alphas,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
def _load_norm_into_transformer(
cls,
state_dict,
transformer,
prefix=None,
discard_original_layers=False,
) -> Dict[str, torch.Tensor]:
# Remove prefix if present
prefix = prefix or cls.transformer_name
for key in list(state_dict.keys()):
if key.split(".")[0] == prefix:
state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)
# Find invalid keys
transformer_state_dict = transformer.state_dict()
transformer_keys = set(transformer_state_dict.keys())
state_dict_keys = set(state_dict.keys())
extra_keys = list(state_dict_keys - transformer_keys)
if extra_keys:
logger.warning(
f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}."
)
for key in extra_keys:
state_dict.pop(key)
# Save the layers that are going to be overwritten so that unload_lora_weights can work as expected
overwritten_layers_state_dict = {}
if not discard_original_layers:
for key in state_dict.keys():
overwritten_layers_state_dict[key] = transformer_state_dict[key].clone()
logger.info(
"The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer "
'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly '
"fused into the transformer and can only be unfused if `discard_original_layers=True` is passed. This might also have implications when dealing with multiple LoRAs. "
"If you notice something unexpected, please open an issue: https://github.com/huggingface/diffusers/issues."
)
# We can't load with strict=True because the current state_dict does not contain all the transformer keys
incompatible_keys = transformer.load_state_dict(state_dict, strict=False)
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
# We shouldn't expect to see the supported norm keys here being present in the unexpected keys.
if unexpected_keys:
if any(norm_key in k for k in unexpected_keys for norm_key in cls._control_lora_supported_norm_keys):
raise ValueError(
f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer."
)
return overwritten_layers_state_dict
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
def load_lora_into_text_encoder(
cls,
state_dict,
network_alphas,
text_encoder,
prefix=None,
lora_scale=1.0,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys should be prefixed with an
additional `text_encoder` to distinguish them from the unet lora layers.
network_alphas (`Dict[str, float]`):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
text_encoder (`CLIPTextModel`):
The text encoder model to load the LoRA layers into.
prefix (`str`):
Expected prefix of the `text_encoder` in the `state_dict`.
lora_scale (`float`):
How much to scale the output of the lora linear layer before it is added to the output of the regular
linear layer.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
_load_lora_into_text_encoder(
state_dict=state_dict,
network_alphas=network_alphas,
lora_scale=lora_scale,
text_encoder=text_encoder,
prefix=prefix,
text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata=None,
text_encoder_lora_adapter_metadata=None,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to. Will be created if it doesn't exist.
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
encoder LoRA state dict because it comes from 🤗 Transformers.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training when you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
text_encoder_lora_adapter_metadata:
LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
"""
lora_layers = {}
lora_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if text_encoder_lora_layers:
lora_layers[cls.text_encoder_name] = text_encoder_lora_layers
lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
cls._save_lora_weights(
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
def fuse_lora(
self,
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
See [`~loaders.LoraBaseMixin.fuse_lora`] for more details.
"""
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
if (
hasattr(transformer, "_transformer_norm_layers")
and isinstance(transformer._transformer_norm_layers, dict)
and len(transformer._transformer_norm_layers.keys()) > 0
):
logger.info(
"The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will be directly updated the state_dict of the transformer "
"as opposed to the LoRA layers that will co-exist separately until the 'fuse_lora()' method is called. That is to say, the normalization layers will always be directly "
"fused into the transformer and can only be unfused if `discard_original_layers=True` is passed."
)
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
> [!WARNING]
> This is an experimental API.
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
"""
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
super().unfuse_lora(components=components, **kwargs)
# We override this here to account for `_transformer_norm_layers` and `_overwritten_params`.
def unload_lora_weights(self, reset_to_overwritten_params=False):
"""
Unloads the LoRA parameters.
Args:
reset_to_overwritten_params (`bool`, defaults to `False`): Whether to reset the LoRA-loaded modules
to their original params. Refer to the [Flux
documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) to learn more.
Examples:
```python
>>> # Assuming `pipeline` is already loaded with the LoRA parameters.
>>> pipeline.unload_lora_weights()
>>> ...
```
"""
super().unload_lora_weights()
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
transformer._transformer_norm_layers = None
if reset_to_overwritten_params and getattr(transformer, "_overwritten_params", None) is not None:
overwritten_params = transformer._overwritten_params
module_names = set()
for param_name in overwritten_params:
if param_name.endswith(".weight"):
module_names.add(param_name.replace(".weight", ""))
for name, module in transformer.named_modules():
if isinstance(module, torch.nn.Linear) and name in module_names:
module_weight = module.weight.data
module_bias = module.bias.data if module.bias is not None else None
bias = module_bias is not None
parent_module_name, _, current_module_name = name.rpartition(".")
parent_module = transformer.get_submodule(parent_module_name)
current_param_weight = overwritten_params[f"{name}.weight"]
in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0]
with torch.device("meta"):
original_module = torch.nn.Linear(
in_features,
out_features,
bias=bias,
dtype=module_weight.dtype,
)
tmp_state_dict = {"weight": current_param_weight}
if module_bias is not None:
tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]})
original_module.load_state_dict(tmp_state_dict, assign=True, strict=True)
setattr(parent_module, current_module_name, original_module)
del tmp_state_dict
if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
new_value = int(current_param_weight.shape[1])
old_value = getattr(transformer.config, attribute_name)
setattr(transformer.config, attribute_name, new_value)
logger.info(
f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
)
@classmethod
def _maybe_expand_transformer_param_shape_or_error_(
cls,
transformer: torch.nn.Module,
lora_state_dict=None,
norm_state_dict=None,
prefix=None,
) -> bool:
"""
Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and
generalizes things a bit so that any parameter that needs expansion receives appropriate treatment.
"""
state_dict = {}
if lora_state_dict is not None:
state_dict.update(lora_state_dict)
if norm_state_dict is not None:
state_dict.update(norm_state_dict)
# Remove prefix if present
prefix = prefix or cls.transformer_name
for key in list(state_dict.keys()):
if key.split(".")[0] == prefix:
state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)
# Expand transformer parameter shapes if they don't match lora
has_param_with_shape_update = False
overwritten_params = {}
is_peft_loaded = getattr(transformer, "peft_config", None) is not None
is_quantized = hasattr(transformer, "hf_quantizer")
for name, module in transformer.named_modules():
if isinstance(module, torch.nn.Linear):
module_weight = module.weight.data
module_bias = module.bias.data if module.bias is not None else None
bias = module_bias is not None
lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
lora_A_weight_name = f"{lora_base_name}.lora_A.weight"
lora_B_weight_name = f"{lora_base_name}.lora_B.weight"
if lora_A_weight_name not in state_dict:
continue
in_features = state_dict[lora_A_weight_name].shape[1]
out_features = state_dict[lora_B_weight_name].shape[0]
# The model may be loaded with different quantization schemes, which may flatten the params.
# `bitsandbytes`, for example, flattens the weights when using 4-bit. 8-bit bnb models
# preserve the weight shape.
module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)
# This means there's no need for an expansion in the params, so we simply skip.
if tuple(module_weight_shape) == (out_features, in_features):
continue
module_out_features, module_in_features = module_weight_shape
debug_message = ""
if in_features > module_in_features:
debug_message += (
f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
f"checkpoint contains higher number of features than expected. The number of input_features will be "
f"expanded from {module_in_features} to {in_features}"
)
if out_features > module_out_features:
debug_message += (
", and the number of output features will be "
f"expanded from {module_out_features} to {out_features}."
)
else:
debug_message += "."
if debug_message:
logger.debug(debug_message)
if out_features > module_out_features or in_features > module_in_features:
has_param_with_shape_update = True
parent_module_name, _, current_module_name = name.rpartition(".")
parent_module = transformer.get_submodule(parent_module_name)
if is_quantized:
module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module)
# TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True.
with torch.device("meta"):
expanded_module = torch.nn.Linear(
in_features, out_features, bias=bias, dtype=module_weight.dtype
)
# Only weights are expanded and biases are not. This is because only the input dimensions
# are changed while the output dimensions remain the same. The shape of the weight tensor
# is (out_features, in_features), while the shape of the bias tensor is (out_features,),
# which is why only the weights are expanded.
new_weight = torch.zeros_like(
expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
)
slices = tuple(slice(0, dim) for dim in module_weight_shape)
new_weight[slices] = module_weight
tmp_state_dict = {"weight": new_weight}
if module_bias is not None:
tmp_state_dict["bias"] = module_bias
expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True)
setattr(parent_module, current_module_name, expanded_module)
del tmp_state_dict
if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
new_value = int(expanded_module.weight.data.shape[1])
old_value = getattr(transformer.config, attribute_name)
setattr(transformer.config, attribute_name, new_value)
logger.info(
f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
)
# For `unload_lora_weights()`.
# TODO: this could lead to more memory overhead if the number of overwritten params
# is large. Should be revisited later and tackled through a `discard_original_layers` arg.
overwritten_params[f"{current_module_name}.weight"] = module_weight
if module_bias is not None:
overwritten_params[f"{current_module_name}.bias"] = module_bias
if len(overwritten_params) > 0:
transformer._overwritten_params = overwritten_params
return has_param_with_shape_update
@classmethod
def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
expanded_module_names = set()
transformer_state_dict = transformer.state_dict()
prefix = f"{cls.transformer_name}."
lora_module_names = [
key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight")
]
lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)]
lora_module_names = sorted(set(lora_module_names))
transformer_module_names = sorted({name for name, _ in transformer.named_modules()})
unexpected_modules = set(lora_module_names) - set(transformer_module_names)
if unexpected_modules:
logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
for k in lora_module_names:
if k in unexpected_modules:
continue
base_param_name = (
f"{k.replace(prefix, '')}.base_layer.weight"
if f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
else f"{k.replace(prefix, '')}.weight"
)
base_weight_param = transformer_state_dict[base_param_name]
lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
# TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization.
base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name)
if base_module_shape[1] > lora_A_param.shape[1]:
shape = (lora_A_param.shape[0], base_weight_param.shape[1])
expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
expanded_module_names.add(k)
elif base_module_shape[1] < lora_A_param.shape[1]:
raise NotImplementedError(
f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
)
if expanded_module_names:
logger.info(
f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
)
return lora_state_dict
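The zero-padding performed above can be shown in isolation. The shapes below are illustrative only, matching the Flux Control case mentioned in `_maybe_expand_transformer_param_shape_or_error_` (input features expanded from 64 to 128).

```python
import torch

rank, lora_in_features, base_in_features = 16, 64, 128
lora_A = torch.randn(rank, lora_in_features)

# Same expansion as above: keep the trained columns, zero-fill the new input slots.
expanded = torch.zeros(rank, base_in_features)
expanded[:, :lora_in_features].copy_(lora_A)

# The padded columns contribute nothing for the original inputs, so behaviour there is unchanged.
assert torch.equal(expanded[:, :lora_in_features], lora_A)
```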
@staticmethod
def _calculate_module_shape(
model: "torch.nn.Module",
base_module: "torch.nn.Linear" = None,
base_weight_param_name: str = None,
) -> "torch.Size":
def _get_weight_shape(weight: torch.Tensor):
if weight.__class__.__name__ == "Params4bit":
return weight.quant_state.shape
elif weight.__class__.__name__ == "GGUFParameter":
return weight.quant_shape
else:
return weight.shape
if base_module is not None:
return _get_weight_shape(base_module.weight)
elif base_weight_param_name is not None:
if not base_weight_param_name.endswith(".weight"):
raise ValueError(
f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
)
module_path = base_weight_param_name.rsplit(".weight", 1)[0]
submodule = get_submodule_by_name(model, module_path)
return _get_weight_shape(submodule.weight)
raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
@staticmethod
def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, return_metadata=False):
outputs = [state_dict]
if return_alphas:
outputs.append(alphas)
if return_metadata:
outputs.append(metadata)
return tuple(outputs) if (return_alphas or return_metadata) else state_dict
return FluxLoraLoaderMixin(*args, **kwargs)
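In effect, the class left behind in `diffusers.loaders` becomes a thin deprecation shim. A sketch of the two import paths after this commit, following the deprecation message above:

```python
# Preferred import location introduced by this commit:
from diffusers.pipelines.flux.lora_utils import FluxLoraLoaderMixin  # noqa: F401

# The old location keeps working for now, but its __new__ emits a deprecation
# warning and returns an instance of the relocated class:
# from diffusers.loaders import FluxLoraLoaderMixin
```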
# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially

View File

@@ -0,0 +1,839 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Callable, Dict, List, Optional, Union
import torch
from ...loaders.lora_base import (
_LOW_CPU_MEM_USAGE_DEFAULT_LORA,
TEXT_ENCODER_NAME,
TRANSFORMER_NAME,
LoraBaseMixin,
_fetch_state_dict,
_load_lora_into_text_encoder,
_maybe_dequantize_weight_for_expanded_lora,
)
from ...loaders.lora_conversion_utils import (
_convert_bfl_flux_control_lora_to_diffusers,
_convert_fal_kontext_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers,
_convert_xlabs_flux_lora_to_diffusers,
)
from ...utils import USE_PEFT_BACKEND, get_submodule_by_name, is_peft_version, logging
from ...utils.hub_utils import validate_hf_hub_args
logger = logging.get_logger(__name__)
_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
class FluxLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`FluxTransformer2DModel`],
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
Specific to [`FluxPipeline`].
"""
_lora_loadable_modules = ["transformer", "text_encoder"]
transformer_name = TRANSFORMER_NAME
text_encoder_name = TEXT_ENCODER_NAME
_control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
return_alphas: bool = False,
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
# TODO (sayakpaul): do a follow-up to clean and try to unify the conditions.
is_kohya = any(".lora_down.weight" in k for k in state_dict)
if is_kohya:
state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
# Kohya already takes care of scaling the LoRA parameters with alpha.
return cls._prepare_outputs(
state_dict,
metadata=metadata,
alphas=None,
return_alphas=return_alphas,
return_metadata=return_lora_metadata,
)
is_xlabs = any("processor" in k for k in state_dict)
if is_xlabs:
state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
# xlabs doesn't use `alpha`.
return cls._prepare_outputs(
state_dict,
metadata=metadata,
alphas=None,
return_alphas=return_alphas,
return_metadata=return_lora_metadata,
)
is_bfl_control = any("query_norm.scale" in k for k in state_dict)
if is_bfl_control:
state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
return cls._prepare_outputs(
state_dict,
metadata=metadata,
alphas=None,
return_alphas=return_alphas,
return_metadata=return_lora_metadata,
)
is_fal_kontext = any("base_model" in k for k in state_dict)
if is_fal_kontext:
state_dict = _convert_fal_kontext_lora_to_diffusers(state_dict)
return cls._prepare_outputs(
state_dict,
metadata=metadata,
alphas=None,
return_alphas=return_alphas,
return_metadata=return_lora_metadata,
)
# For state dicts like
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
keys = list(state_dict.keys())
network_alphas = {}
for k in keys:
if "alpha" in k:
alpha_value = state_dict.get(k)
if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
alpha_value, float
):
network_alphas[k] = state_dict.pop(k)
else:
raise ValueError(
f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue."
)
if return_alphas or return_lora_metadata:
return cls._prepare_outputs(
state_dict,
metadata=metadata,
alphas=network_alphas,
return_alphas=return_alphas,
return_metadata=return_lora_metadata,
)
else:
return state_dict
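A hedged usage sketch for this classmethod; `"some-user/flux-lora"` is a placeholder repository id, not a real checkpoint.

```python
from diffusers.pipelines.flux.lora_utils import FluxLoraLoaderMixin

# Fetch and normalize a Flux LoRA checkpoint without constructing a pipeline.
state_dict, network_alphas = FluxLoraLoaderMixin.lora_state_dict(
    "some-user/flux-lora", return_alphas=True
)
# Checkpoints converted from other trainers (kohya, xlabs, ...) return alphas=None.
print(f"{len(state_dict)} LoRA tensors, {len(network_alphas or {})} alpha entries")
```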
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
`self.text_encoder`.
All kwargs are forwarded to `self.lora_state_dict`.
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
loaded.
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
dict is loaded into `self.transformer`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
kwargs["return_lora_metadata"] = True
state_dict, network_alphas, metadata = self.lora_state_dict(
pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
)
has_lora_keys = any("lora" in key for key in state_dict.keys())
# Flux Control LoRAs also have norm keys
has_norm_keys = any(
norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys
)
if not (has_lora_keys or has_norm_keys):
raise ValueError("Invalid LoRA checkpoint.")
transformer_lora_state_dict = {
k: state_dict.get(k)
for k in list(state_dict.keys())
if k.startswith(f"{self.transformer_name}.") and "lora" in k
}
transformer_norm_state_dict = {
k: state_dict.pop(k)
for k in list(state_dict.keys())
if k.startswith(f"{self.transformer_name}.")
and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
}
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
has_param_with_expanded_shape = False
if len(transformer_lora_state_dict) > 0:
has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
transformer, transformer_lora_state_dict, transformer_norm_state_dict
)
if has_param_with_expanded_shape:
logger.info(
"The LoRA weights contain parameters that have different shapes that expected by the transformer. "
"As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
"To get a comprehensive list of parameter names that were modified, enable debug logging."
)
if len(transformer_lora_state_dict) > 0:
transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
transformer=transformer, lora_state_dict=transformer_lora_state_dict
)
for k in transformer_lora_state_dict:
state_dict.update({k: transformer_lora_state_dict[k]})
self.load_lora_into_transformer(
state_dict,
network_alphas=network_alphas,
transformer=transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
if len(transformer_norm_state_dict) > 0:
transformer._transformer_norm_layers = self._load_norm_into_transformer(
transformer_norm_state_dict,
transformer=transformer,
discard_original_layers=False,
)
self.load_lora_into_text_encoder(
state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder,
prefix=self.text_encoder_name,
lora_scale=self.lora_scale,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
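For completeness, a hedged end-to-end sketch from the pipeline side; the LoRA repository id is a placeholder and a CUDA device is assumed.

```python
import torch

from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Dispatches into FluxLoraLoaderMixin.load_lora_weights defined above.
pipe.load_lora_weights("some-user/flux-lora", adapter_name="style")
image = pipe("a corgi astronaut, watercolor", num_inference_steps=28).images[0]
image.save("corgi.png")
```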
@classmethod
def load_lora_into_transformer(
cls,
state_dict,
network_alphas,
transformer,
adapter_name=None,
metadata=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=network_alphas,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
def _load_norm_into_transformer(
cls,
state_dict,
transformer,
prefix=None,
discard_original_layers=False,
) -> Dict[str, torch.Tensor]:
# Remove prefix if present
prefix = prefix or cls.transformer_name
for key in list(state_dict.keys()):
if key.split(".")[0] == prefix:
state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)
# Find invalid keys
transformer_state_dict = transformer.state_dict()
transformer_keys = set(transformer_state_dict.keys())
state_dict_keys = set(state_dict.keys())
extra_keys = list(state_dict_keys - transformer_keys)
if extra_keys:
logger.warning(
f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}."
)
for key in extra_keys:
state_dict.pop(key)
# Save the layers that are going to be overwritten so that unload_lora_weights can work as expected
overwritten_layers_state_dict = {}
if not discard_original_layers:
for key in state_dict.keys():
overwritten_layers_state_dict[key] = transformer_state_dict[key].clone()
logger.info(
"The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer "
'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly '
"fused into the transformer and can only be unfused if `discard_original_layers=True` is passed. This might also have implications when dealing with multiple LoRAs. "
"If you notice something unexpected, please open an issue: https://github.com/huggingface/diffusers/issues."
)
# We can't load with strict=True because the current state_dict does not contain all the transformer keys
incompatible_keys = transformer.load_state_dict(state_dict, strict=False)
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
# We shouldn't expect to see the supported norm keys here being present in the unexpected keys.
if unexpected_keys:
if any(norm_key in k for k in unexpected_keys for norm_key in cls._control_lora_supported_norm_keys):
raise ValueError(
f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer."
)
return overwritten_layers_state_dict
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
def load_lora_into_text_encoder(
cls,
state_dict,
network_alphas,
text_encoder,
prefix=None,
lora_scale=1.0,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys should be prefixed with an
additional `text_encoder` to distinguish them from the unet lora layers.
network_alphas (`Dict[str, float]`):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
text_encoder (`CLIPTextModel`):
The text encoder model to load the LoRA layers into.
prefix (`str`):
Expected prefix of the `text_encoder` in the `state_dict`.
lora_scale (`float`):
How much to scale the output of the lora linear layer before it is added to the output of the regular
linear layer.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
_load_lora_into_text_encoder(
state_dict=state_dict,
network_alphas=network_alphas,
lora_scale=lora_scale,
text_encoder=text_encoder,
prefix=prefix,
text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata=None,
text_encoder_lora_adapter_metadata=None,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to. Will be created if it doesn't exist.
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
encoder LoRA state dict because it comes from 🤗 Transformers.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training when you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
text_encoder_lora_adapter_metadata:
LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
"""
lora_layers = {}
lora_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if text_encoder_lora_layers:
lora_layers[cls.text_encoder_name] = text_encoder_lora_layers
lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
cls._save_lora_weights(
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
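A hedged sketch of saving adapters with this classmethod, continuing from the `pipe` constructed in the earlier sketch; it assumes `pipe.transformer` already carries a PEFT LoRA adapter and uses `get_peft_model_state_dict` from `peft`.

```python
from peft.utils import get_peft_model_state_dict

from diffusers import FluxPipeline

# `pipe` is the FluxPipeline from the earlier sketch, with a LoRA adapter loaded.
transformer_lora_layers = get_peft_model_state_dict(pipe.transformer)

FluxPipeline.save_lora_weights(
    save_directory="./my-flux-lora",
    transformer_lora_layers=transformer_lora_layers,
    safe_serialization=True,
)
```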
def fuse_lora(
self,
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
See [`~loaders.LoraBaseMixin.fuse_lora`] for more details.
"""
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
if (
hasattr(transformer, "_transformer_norm_layers")
and isinstance(transformer._transformer_norm_layers, dict)
and len(transformer._transformer_norm_layers.keys()) > 0
):
logger.info(
"The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will be directly updated the state_dict of the transformer "
"as opposed to the LoRA layers that will co-exist separately until the 'fuse_lora()' method is called. That is to say, the normalization layers will always be directly "
"fused into the transformer and can only be unfused if `discard_original_layers=True` is passed."
)
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
> [!WARNING]
> This is an experimental API.
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
"""
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
super().unfuse_lora(components=components, **kwargs)
# We override this here to account for `_transformer_norm_layers` and `_overwritten_params`.
def unload_lora_weights(self, reset_to_overwritten_params=False):
"""
Unloads the LoRA parameters.
Args:
reset_to_overwritten_params (`bool`, defaults to `False`): Whether to reset the LoRA-loaded modules
to their original params. Refer to the [Flux
documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) to learn more.
Examples:
```python
>>> # Assuming `pipeline` is already loaded with the LoRA parameters.
>>> pipeline.unload_lora_weights()
>>> ...
```
"""
super().unload_lora_weights()
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
transformer._transformer_norm_layers = None
if reset_to_overwritten_params and getattr(transformer, "_overwritten_params", None) is not None:
overwritten_params = transformer._overwritten_params
module_names = set()
for param_name in overwritten_params:
if param_name.endswith(".weight"):
module_names.add(param_name.replace(".weight", ""))
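# Re-create each expanded `nn.Linear` with its original (pre-expansion) weights and restore the
# matching config attribute, if any.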
for name, module in transformer.named_modules():
if isinstance(module, torch.nn.Linear) and name in module_names:
module_weight = module.weight.data
module_bias = module.bias.data if module.bias is not None else None
bias = module_bias is not None
parent_module_name, _, current_module_name = name.rpartition(".")
parent_module = transformer.get_submodule(parent_module_name)
current_param_weight = overwritten_params[f"{name}.weight"]
in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0]
with torch.device("meta"):
original_module = torch.nn.Linear(
in_features,
out_features,
bias=bias,
dtype=module_weight.dtype,
)
tmp_state_dict = {"weight": current_param_weight}
if module_bias is not None:
tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]})
original_module.load_state_dict(tmp_state_dict, assign=True, strict=True)
setattr(parent_module, current_module_name, original_module)
del tmp_state_dict
if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
new_value = int(current_param_weight.shape[1])
old_value = getattr(transformer.config, attribute_name)
setattr(transformer.config, attribute_name, new_value)
logger.info(
f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
)
@classmethod
def _maybe_expand_transformer_param_shape_or_error_(
cls,
transformer: torch.nn.Module,
lora_state_dict=None,
norm_state_dict=None,
prefix=None,
) -> bool:
"""
Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that case and
generalizes it so that any parameter that needs expansion receives appropriate treatment.
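Illustrative sketch of the expansion performed below (placeholder shapes, not the actual implementation):

```python
>>> import torch

>>> old = torch.nn.Linear(64, 3072)  # e.g. the input layer before Control LoRA
>>> new = torch.nn.Linear(128, 3072)  # expanded number of input features
>>> with torch.no_grad():
...     new.weight.zero_()
...     new.weight[:, :64].copy_(old.weight)  # original weights occupy the leading slice
...     new.bias.copy_(old.bias)  # bias shape (out_features,) is unchanged
```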
"""
state_dict = {}
if lora_state_dict is not None:
state_dict.update(lora_state_dict)
if norm_state_dict is not None:
state_dict.update(norm_state_dict)
# Remove prefix if present
prefix = prefix or cls.transformer_name
for key in list(state_dict.keys()):
if key.split(".")[0] == prefix:
state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)
# Expand transformer parameter shapes if they don't match lora
has_param_with_shape_update = False
overwritten_params = {}
is_peft_loaded = getattr(transformer, "peft_config", None) is not None
is_quantized = hasattr(transformer, "hf_quantizer")
for name, module in transformer.named_modules():
if isinstance(module, torch.nn.Linear):
module_weight = module.weight.data
module_bias = module.bias.data if module.bias is not None else None
bias = module_bias is not None
lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
lora_A_weight_name = f"{lora_base_name}.lora_A.weight"
lora_B_weight_name = f"{lora_base_name}.lora_B.weight"
if lora_A_weight_name not in state_dict:
continue
in_features = state_dict[lora_A_weight_name].shape[1]
out_features = state_dict[lora_B_weight_name].shape[0]
# The model may be loaded with different quantization schemes, some of which flatten the params.
# `bitsandbytes`, for example, flattens the weights when using 4-bit quantization, while 8-bit bnb
# models preserve the weight shape.
module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)
# This means there's no need for an expansion in the params, so we simply skip.
if tuple(module_weight_shape) == (out_features, in_features):
continue
module_out_features, module_in_features = module_weight_shape
debug_message = ""
if in_features > module_in_features:
debug_message += (
f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
f"checkpoint contains a higher number of features than expected. The number of input features will be "
f"expanded from {module_in_features} to {in_features}"
)
if out_features > module_out_features:
debug_message += (
", and the number of output features will be "
f"expanded from {module_out_features} to {out_features}."
)
else:
debug_message += "."
if debug_message:
logger.debug(debug_message)
if out_features > module_out_features or in_features > module_in_features:
has_param_with_shape_update = True
parent_module_name, _, current_module_name = name.rpartition(".")
parent_module = transformer.get_submodule(parent_module_name)
if is_quantized:
module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module)
# TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True.
with torch.device("meta"):
expanded_module = torch.nn.Linear(
in_features, out_features, bias=bias, dtype=module_weight.dtype
)
# Only the weights are expanded, not the biases: only the input dimensions change while the
# output dimensions remain the same. The weight tensor has shape (out_features, in_features),
# whereas the bias tensor has shape (out_features,), which is why only the weights need expansion.
new_weight = torch.zeros_like(
expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
)
slices = tuple(slice(0, dim) for dim in module_weight_shape)
new_weight[slices] = module_weight
tmp_state_dict = {"weight": new_weight}
if module_bias is not None:
tmp_state_dict["bias"] = module_bias
expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True)
setattr(parent_module, current_module_name, expanded_module)
del tmp_state_dict
if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
new_value = int(expanded_module.weight.data.shape[1])
old_value = getattr(transformer.config, attribute_name)
setattr(transformer.config, attribute_name, new_value)
logger.info(
f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
)
# For `unload_lora_weights()`.
# TODO: this could lead to more memory overhead if the number of overwritten params
# is large. Should be revisited later and tackled through a `discard_original_layers` arg.
overwritten_params[f"{current_module_name}.weight"] = module_weight
if module_bias is not None:
overwritten_params[f"{current_module_name}.bias"] = module_bias
if len(overwritten_params) > 0:
transformer._overwritten_params = overwritten_params
return has_param_with_shape_update
@classmethod
def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
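"""
Zero-pads `lora_A` weights in `lora_state_dict` along the input-feature dimension so that they match
the (possibly expanded) base layers of `transformer`. Raises `NotImplementedError` if a LoRA matrix is
wider than its base layer.
"""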
expanded_module_names = set()
transformer_state_dict = transformer.state_dict()
prefix = f"{cls.transformer_name}."
lora_module_names = [
key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight")
]
lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)]
lora_module_names = sorted(set(lora_module_names))
transformer_module_names = sorted({name for name, _ in transformer.named_modules()})
unexpected_modules = set(lora_module_names) - set(transformer_module_names)
if unexpected_modules:
logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
for k in lora_module_names:
if k in unexpected_modules:
continue
base_param_name = (
f"{k.replace(prefix, '')}.base_layer.weight"
if f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
else f"{k.replace(prefix, '')}.weight"
)
base_weight_param = transformer_state_dict[base_param_name]
lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
# TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization.
base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name)
if base_module_shape[1] > lora_A_param.shape[1]:
shape = (lora_A_param.shape[0], base_weight_param.shape[1])
expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
expanded_module_names.add(k)
elif base_module_shape[1] < lora_A_param.shape[1]:
raise NotImplementedError(
f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
)
if expanded_module_names:
logger.info(
f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
)
return lora_state_dict
@staticmethod
def _calculate_module_shape(
model: "torch.nn.Module",
base_module: "torch.nn.Linear" = None,
base_weight_param_name: str = None,
) -> "torch.Size":
def _get_weight_shape(weight: torch.Tensor):
if weight.__class__.__name__ == "Params4bit":
return weight.quant_state.shape
elif weight.__class__.__name__ == "GGUFParameter":
return weight.quant_shape
else:
return weight.shape
if base_module is not None:
return _get_weight_shape(base_module.weight)
elif base_weight_param_name is not None:
if not base_weight_param_name.endswith(".weight"):
raise ValueError(
f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
)
module_path = base_weight_param_name.rsplit(".weight", 1)[0]
submodule = get_submodule_by_name(model, module_path)
return _get_weight_shape(submodule.weight)
raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
@staticmethod
def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, return_metadata=False):
outputs = [state_dict]
if return_alphas:
outputs.append(alphas)
if return_metadata:
outputs.append(metadata)
return tuple(outputs) if (return_alphas or return_metadata) else state_dict

View File

@@ -26,16 +26,13 @@ from transformers import (
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...loaders import FluxIPAdapterMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
is_torch_xla_available,
logging,
replace_example_docstring,
)
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .lora_utils import FluxLoraLoaderMixin
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_timesteps
from .pipeline_output import FluxPipelineOutput

View File

@@ -19,17 +19,14 @@ import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...loaders import FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
is_torch_xla_available,
logging,
replace_example_docstring,
)
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .lora_utils import FluxLoraLoaderMixin
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_timesteps
from .pipeline_output import FluxPipelineOutput

View File

@@ -19,13 +19,14 @@ import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
from ...loaders import FromSingleFileMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .lora_utils import FluxLoraLoaderMixin
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps
from .pipeline_output import FluxPipelineOutput

View File

@@ -24,17 +24,14 @@ from transformers import (
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import (
FluxLoraLoaderMixin,
FromSingleFileMixin,
TextualInversionLoaderMixin,
)
from ...loaders import FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .lora_utils import FluxLoraLoaderMixin
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps
from .pipeline_output import FluxPipelineOutput

View File

@@ -26,7 +26,7 @@ from transformers import (
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin
from ...loaders import FluxIPAdapterMixin, FromSingleFileMixin
from ...models.autoencoders import AutoencoderKL
from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from ...models.transformers import FluxTransformer2DModel
@@ -34,6 +34,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .lora_utils import FluxLoraLoaderMixin
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps
from .pipeline_output import FluxPipelineOutput

View File

@@ -10,7 +10,7 @@ from transformers import (
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
from ...loaders import FromSingleFileMixin
from ...models.autoencoders import AutoencoderKL
from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from ...models.transformers import FluxTransformer2DModel
@@ -18,6 +18,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .lora_utils import FluxLoraLoaderMixin
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps
from .pipeline_output import FluxPipelineOutput

View File

@@ -11,7 +11,7 @@ from transformers import (
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
from ...loaders import FromSingleFileMixin
from ...models.autoencoders import AutoencoderKL
from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from ...models.transformers import FluxTransformer2DModel
@@ -19,6 +19,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .lora_utils import FluxLoraLoaderMixin
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps
from .pipeline_output import FluxPipelineOutput

View File

@@ -19,13 +19,14 @@ import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...image_processor import VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...loaders import FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .lora_utils import FluxLoraLoaderMixin
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps
from .pipeline_output import FluxPipelineOutput

View File

@@ -26,13 +26,14 @@ from transformers import (
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin
from ...loaders import FluxIPAdapterMixin, FromSingleFileMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .lora_utils import FluxLoraLoaderMixin
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps
from .pipeline_output import FluxPipelineOutput

View File

@@ -27,7 +27,7 @@ from transformers import (
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin
from ...loaders import FluxIPAdapterMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -38,6 +38,7 @@ from ...utils import (
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .lora_utils import FluxLoraLoaderMixin
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps
from .pipeline_output import FluxPipelineOutput

View File

@@ -26,12 +26,13 @@ from transformers import (
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...loaders import FluxIPAdapterMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .lora_utils import FluxLoraLoaderMixin
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps
from .pipeline_output import FluxPipelineOutput

View File

@@ -16,12 +16,13 @@ from transformers import (
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...loaders import FluxIPAdapterMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .lora_utils import FluxLoraLoaderMixin
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps
from .pipeline_output import FluxPipelineOutput