diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 3d75a7d875..e53ac40a62 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -35,8 +35,11 @@ from ..utils import ( deprecate, get_adapter_name, is_accelerate_available, + is_bitsandbytes_available, + is_gguf_available, is_peft_available, is_peft_version, + is_torch_version, is_transformers_available, is_transformers_version, logging, @@ -64,6 +67,20 @@ LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata" +TEXT_ENCODER_NAME = "text_encoder" +UNET_NAME = "unet" +TRANSFORMER_NAME = "transformer" + +_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False +if is_torch_version(">=", "1.9.0"): + if ( + is_peft_available() + and is_peft_version(">=", "0.13.1") + and is_transformers_available() + and is_transformers_version(">", "4.45.2") + ): + _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True + def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None): """ @@ -475,6 +492,55 @@ def _func_optionally_disable_offloading(_pipeline): return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload) +def _maybe_dequantize_weight_for_expanded_lora(model, module): + if is_bitsandbytes_available(): + from ..quantizers.bitsandbytes import dequantize_bnb_weight + + if is_gguf_available(): + from ..quantizers.gguf.utils import dequantize_gguf_tensor + + is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit" + is_bnb_8bit_quantized = module.weight.__class__.__name__ == "Int8Params" + is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter" + + if is_bnb_4bit_quantized and not is_bitsandbytes_available(): + raise ValueError( + "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints." + ) + if is_bnb_8bit_quantized and not is_bitsandbytes_available(): + raise ValueError( + "The checkpoint seems to have been quantized with `bitsandbytes` (8bits). Install `bitsandbytes` to load quantized checkpoints." + ) + if is_gguf_quantized and not is_gguf_available(): + raise ValueError( + "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints." + ) + + weight_on_cpu = False + if module.weight.device.type == "cpu": + weight_on_cpu = True + + device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda" + if is_bnb_4bit_quantized or is_bnb_8bit_quantized: + module_weight = dequantize_bnb_weight( + module.weight.to(device) if weight_on_cpu else module.weight, + state=module.weight.quant_state if is_bnb_4bit_quantized else module.state, + dtype=model.dtype, + ).data + elif is_gguf_quantized: + module_weight = dequantize_gguf_tensor( + module.weight.to(device) if weight_on_cpu else module.weight, + ) + module_weight = module_weight.to(model.dtype) + else: + module_weight = module.weight.data + + if weight_on_cpu: + module_weight = module_weight.cpu() + + return module_weight + + class LoraBaseMixin: """Utility class for handling LoRAs.""" diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index bcbe54649f..6817c129d9 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -21,29 +21,24 @@ from huggingface_hub.utils import validate_hf_hub_args from ..utils import ( USE_PEFT_BACKEND, deprecate, - get_submodule_by_name, - is_bitsandbytes_available, - is_gguf_available, - is_peft_available, is_peft_version, - is_torch_version, - is_transformers_available, - is_transformers_version, logging, ) from .lora_base import ( # noqa + _LOW_CPU_MEM_USAGE_DEFAULT_LORA, LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, + TEXT_ENCODER_NAME, + TRANSFORMER_NAME, + UNET_NAME, LoraBaseMixin, _fetch_state_dict, _load_lora_into_text_encoder, + _maybe_dequantize_weight_for_expanded_lora, _pack_dict_with_prefix, ) from .lora_conversion_utils import ( - _convert_bfl_flux_control_lora_to_diffusers, - _convert_fal_kontext_lora_to_diffusers, _convert_hunyuan_video_lora_to_diffusers, - _convert_kohya_flux_lora_to_diffusers, _convert_musubi_wan_lora_to_diffusers, _convert_non_diffusers_flux2_lora_to_diffusers, _convert_non_diffusers_hidream_lora_to_diffusers, @@ -53,79 +48,12 @@ from .lora_conversion_utils import ( _convert_non_diffusers_qwen_lora_to_diffusers, _convert_non_diffusers_wan_lora_to_diffusers, _convert_non_diffusers_z_image_lora_to_diffusers, - _convert_xlabs_flux_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers, ) -_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False -if is_torch_version(">=", "1.9.0"): - if ( - is_peft_available() - and is_peft_version(">=", "0.13.1") - and is_transformers_available() - and is_transformers_version(">", "4.45.2") - ): - _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True - - logger = logging.get_logger(__name__) -TEXT_ENCODER_NAME = "text_encoder" -UNET_NAME = "unet" -TRANSFORMER_NAME = "transformer" - -_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"} - - -def _maybe_dequantize_weight_for_expanded_lora(model, module): - if is_bitsandbytes_available(): - from ..quantizers.bitsandbytes import dequantize_bnb_weight - - if is_gguf_available(): - from ..quantizers.gguf.utils import dequantize_gguf_tensor - - is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit" - is_bnb_8bit_quantized = module.weight.__class__.__name__ == "Int8Params" - is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter" - - if is_bnb_4bit_quantized and not is_bitsandbytes_available(): - raise ValueError( - "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints." - ) - if is_bnb_8bit_quantized and not is_bitsandbytes_available(): - raise ValueError( - "The checkpoint seems to have been quantized with `bitsandbytes` (8bits). Install `bitsandbytes` to load quantized checkpoints." - ) - if is_gguf_quantized and not is_gguf_available(): - raise ValueError( - "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints." - ) - - weight_on_cpu = False - if module.weight.device.type == "cpu": - weight_on_cpu = True - - device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda" - if is_bnb_4bit_quantized or is_bnb_8bit_quantized: - module_weight = dequantize_bnb_weight( - module.weight.to(device) if weight_on_cpu else module.weight, - state=module.weight.quant_state if is_bnb_4bit_quantized else module.state, - dtype=model.dtype, - ).data - elif is_gguf_quantized: - module_weight = dequantize_gguf_tensor( - module.weight.to(device) if weight_on_cpu else module.weight, - ) - module_weight = module_weight.to(model.dtype) - else: - module_weight = module.weight.data - - if weight_on_cpu: - module_weight = module_weight.cpu() - - return module_weight - class StableDiffusionLoraLoaderMixin(LoraBaseMixin): r""" @@ -1483,802 +1411,15 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): class FluxLoraLoaderMixin(LoraBaseMixin): - r""" - Load LoRA layers into [`FluxTransformer2DModel`], - [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). - - Specific to [`StableDiffusion3Pipeline`]. - """ - - _lora_loadable_modules = ["transformer", "text_encoder"] - transformer_name = TRANSFORMER_NAME - text_encoder_name = TEXT_ENCODER_NAME - _control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"] - - @classmethod - @validate_hf_hub_args - def lora_state_dict( - cls, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - return_alphas: bool = False, - **kwargs, - ): - r""" - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details. - """ - # Load the main state dict first which has the LoRA layers for either of - # transformer and text encoder or both. - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - return_lora_metadata = kwargs.pop("return_lora_metadata", False) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - - state_dict, metadata = _fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - allow_pickle=allow_pickle, + def __new__(cls, *args, **kwargs): + deprecation_message = ( + "Importing `FluxLoraLoaderMixin` class like `from diffusers.loaders import FluxLoraLoaderMixin` is deprecated and will be removed in a future version. " + "Please use `from diffusers.pipelines.flux.lora_utils import FluxLoraLoaderMixin` instead. " ) - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." - logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + deprecate("FluxLoraLoaderMixin", "1.0.0", deprecation_message, standard_warn=False) + from ..pipelines.flux.lora_utils import FluxLoraLoaderMixin - # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions. - is_kohya = any(".lora_down.weight" in k for k in state_dict) - if is_kohya: - state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict) - # Kohya already takes care of scaling the LoRA parameters with alpha. - return cls._prepare_outputs( - state_dict, - metadata=metadata, - alphas=None, - return_alphas=return_alphas, - return_metadata=return_lora_metadata, - ) - - is_xlabs = any("processor" in k for k in state_dict) - if is_xlabs: - state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict) - # xlabs doesn't use `alpha`. - return cls._prepare_outputs( - state_dict, - metadata=metadata, - alphas=None, - return_alphas=return_alphas, - return_metadata=return_lora_metadata, - ) - - is_bfl_control = any("query_norm.scale" in k for k in state_dict) - if is_bfl_control: - state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict) - return cls._prepare_outputs( - state_dict, - metadata=metadata, - alphas=None, - return_alphas=return_alphas, - return_metadata=return_lora_metadata, - ) - - is_fal_kontext = any("base_model" in k for k in state_dict) - if is_fal_kontext: - state_dict = _convert_fal_kontext_lora_to_diffusers(state_dict) - return cls._prepare_outputs( - state_dict, - metadata=metadata, - alphas=None, - return_alphas=return_alphas, - return_metadata=return_lora_metadata, - ) - - # For state dicts like - # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA - keys = list(state_dict.keys()) - network_alphas = {} - for k in keys: - if "alpha" in k: - alpha_value = state_dict.get(k) - if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance( - alpha_value, float - ): - network_alphas[k] = state_dict.pop(k) - else: - raise ValueError( - f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue." - ) - - if return_alphas or return_lora_metadata: - return cls._prepare_outputs( - state_dict, - metadata=metadata, - alphas=network_alphas, - return_alphas=return_alphas, - return_metadata=return_lora_metadata, - ) - else: - return state_dict - - def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, - ): - """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and - `self.text_encoder`. - - All kwargs are forwarded to `self.lora_state_dict`. - - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is - loaded. - - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state - dict is loaded into `self.transformer`. - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - `Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # if a dict is passed, copy it instead of modifying it inplace - if isinstance(pretrained_model_name_or_path_or_dict, dict): - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() - - # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - kwargs["return_lora_metadata"] = True - state_dict, network_alphas, metadata = self.lora_state_dict( - pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs - ) - - has_lora_keys = any("lora" in key for key in state_dict.keys()) - - # Flux Control LoRAs also have norm keys - has_norm_keys = any( - norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys - ) - - if not (has_lora_keys or has_norm_keys): - raise ValueError("Invalid LoRA checkpoint.") - - transformer_lora_state_dict = { - k: state_dict.get(k) - for k in list(state_dict.keys()) - if k.startswith(f"{self.transformer_name}.") and "lora" in k - } - transformer_norm_state_dict = { - k: state_dict.pop(k) - for k in list(state_dict.keys()) - if k.startswith(f"{self.transformer_name}.") - and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys) - } - - transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer - has_param_with_expanded_shape = False - if len(transformer_lora_state_dict) > 0: - has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_( - transformer, transformer_lora_state_dict, transformer_norm_state_dict - ) - - if has_param_with_expanded_shape: - logger.info( - "The LoRA weights contain parameters that have different shapes that expected by the transformer. " - "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. " - "To get a comprehensive list of parameter names that were modified, enable debug logging." - ) - if len(transformer_lora_state_dict) > 0: - transformer_lora_state_dict = self._maybe_expand_lora_state_dict( - transformer=transformer, lora_state_dict=transformer_lora_state_dict - ) - for k in transformer_lora_state_dict: - state_dict.update({k: transformer_lora_state_dict[k]}) - - self.load_lora_into_transformer( - state_dict, - network_alphas=network_alphas, - transformer=transformer, - adapter_name=adapter_name, - metadata=metadata, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - if len(transformer_norm_state_dict) > 0: - transformer._transformer_norm_layers = self._load_norm_into_transformer( - transformer_norm_state_dict, - transformer=transformer, - discard_original_layers=False, - ) - - self.load_lora_into_text_encoder( - state_dict, - network_alphas=network_alphas, - text_encoder=self.text_encoder, - prefix=self.text_encoder_name, - lora_scale=self.lora_scale, - adapter_name=adapter_name, - metadata=metadata, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - def load_lora_into_transformer( - cls, - state_dict, - network_alphas, - transformer, - adapter_name=None, - metadata=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - ): - """ - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details. - """ - if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # Load the layers corresponding to transformer. - logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=network_alphas, - adapter_name=adapter_name, - metadata=metadata, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - def _load_norm_into_transformer( - cls, - state_dict, - transformer, - prefix=None, - discard_original_layers=False, - ) -> Dict[str, torch.Tensor]: - # Remove prefix if present - prefix = prefix or cls.transformer_name - for key in list(state_dict.keys()): - if key.split(".")[0] == prefix: - state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key) - - # Find invalid keys - transformer_state_dict = transformer.state_dict() - transformer_keys = set(transformer_state_dict.keys()) - state_dict_keys = set(state_dict.keys()) - extra_keys = list(state_dict_keys - transformer_keys) - - if extra_keys: - logger.warning( - f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}." - ) - - for key in extra_keys: - state_dict.pop(key) - - # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected - overwritten_layers_state_dict = {} - if not discard_original_layers: - for key in state_dict.keys(): - overwritten_layers_state_dict[key] = transformer_state_dict[key].clone() - - logger.info( - "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer " - 'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly ' - "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed. This might also have implications when dealing with multiple LoRAs. " - "If you notice something unexpected, please open an issue: https://github.com/huggingface/diffusers/issues." - ) - - # We can't load with strict=True because the current state_dict does not contain all the transformer keys - incompatible_keys = transformer.load_state_dict(state_dict, strict=False) - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - - # We shouldn't expect to see the supported norm keys here being present in the unexpected keys. - if unexpected_keys: - if any(norm_key in k for k in unexpected_keys for norm_key in cls._control_lora_supported_norm_keys): - raise ValueError( - f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer." - ) - - return overwritten_layers_state_dict - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder - def load_lora_into_text_encoder( - cls, - state_dict, - network_alphas, - text_encoder, - prefix=None, - lora_scale=1.0, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - metadata=None, - ): - """ - This will load the LoRA layers specified in `state_dict` into `text_encoder` - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The key should be prefixed with an - additional `text_encoder` to distinguish between unet lora layers. - network_alphas (`Dict[str, float]`): - The value of the network alpha used for stable learning and preventing underflow. This value has the - same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this - link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). - text_encoder (`CLIPTextModel`): - The text encoder model to load the LoRA layers into. - prefix (`str`): - Expected prefix of the `text_encoder` in the `state_dict`. - lora_scale (`float`): - How much to scale the output of the lora linear layer before it is added with the output of the regular - lora layer. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. - """ - _load_lora_into_text_encoder( - state_dict=state_dict, - network_alphas=network_alphas, - lora_scale=lora_scale, - text_encoder=text_encoder, - prefix=prefix, - text_encoder_name=cls.text_encoder_name, - adapter_name=adapter_name, - metadata=metadata, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer - def save_lora_weights( - cls, - save_directory: Union[str, os.PathLike], - transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - transformer_lora_adapter_metadata=None, - text_encoder_lora_adapter_metadata=None, - ): - r""" - Save the LoRA parameters corresponding to the UNet and text encoder. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to save LoRA parameters to. Will be created if it doesn't exist. - transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `transformer`. - text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text - encoder LoRA state dict because it comes from 🤗 Transformers. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful during distributed training and you - need to call this function on all processes. In this case, set `is_main_process=True` only on the main - process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful during distributed training when you need to - replace `torch.save` with another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - transformer_lora_adapter_metadata: - LoRA adapter metadata associated with the transformer to be serialized with the state dict. - text_encoder_lora_adapter_metadata: - LoRA adapter metadata associated with the text encoder to be serialized with the state dict. - """ - lora_layers = {} - lora_metadata = {} - - if transformer_lora_layers: - lora_layers[cls.transformer_name] = transformer_lora_layers - lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - - if text_encoder_lora_layers: - lora_layers[cls.text_encoder_name] = text_encoder_lora_layers - lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata - - if not lora_layers: - raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - - cls._save_lora_weights( - save_directory=save_directory, - lora_layers=lora_layers, - lora_metadata=lora_metadata, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - ) - - def fuse_lora( - self, - components: List[str] = ["transformer"], - lora_scale: float = 1.0, - safe_fusing: bool = False, - adapter_names: Optional[List[str]] = None, - **kwargs, - ): - r""" - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details. - """ - - transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer - if ( - hasattr(transformer, "_transformer_norm_layers") - and isinstance(transformer._transformer_norm_layers, dict) - and len(transformer._transformer_norm_layers.keys()) > 0 - ): - logger.info( - "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will be directly updated the state_dict of the transformer " - "as opposed to the LoRA layers that will co-exist separately until the 'fuse_lora()' method is called. That is to say, the normalization layers will always be directly " - "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed." - ) - - super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, - ) - - def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): - r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). - - > [!WARNING] > This is an experimental API. - - Args: - components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - """ - transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer - if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers: - transformer.load_state_dict(transformer._transformer_norm_layers, strict=False) - - super().unfuse_lora(components=components, **kwargs) - - # We override this here account for `_transformer_norm_layers` and `_overwritten_params`. - def unload_lora_weights(self, reset_to_overwritten_params=False): - """ - Unloads the LoRA parameters. - - Args: - reset_to_overwritten_params (`bool`, defaults to `False`): Whether to reset the LoRA-loaded modules - to their original params. Refer to the [Flux - documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) to learn more. - - Examples: - - ```python - >>> # Assuming `pipeline` is already loaded with the LoRA parameters. - >>> pipeline.unload_lora_weights() - >>> ... - ``` - """ - super().unload_lora_weights() - - transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer - if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers: - transformer.load_state_dict(transformer._transformer_norm_layers, strict=False) - transformer._transformer_norm_layers = None - - if reset_to_overwritten_params and getattr(transformer, "_overwritten_params", None) is not None: - overwritten_params = transformer._overwritten_params - module_names = set() - - for param_name in overwritten_params: - if param_name.endswith(".weight"): - module_names.add(param_name.replace(".weight", "")) - - for name, module in transformer.named_modules(): - if isinstance(module, torch.nn.Linear) and name in module_names: - module_weight = module.weight.data - module_bias = module.bias.data if module.bias is not None else None - bias = module_bias is not None - - parent_module_name, _, current_module_name = name.rpartition(".") - parent_module = transformer.get_submodule(parent_module_name) - - current_param_weight = overwritten_params[f"{name}.weight"] - in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0] - with torch.device("meta"): - original_module = torch.nn.Linear( - in_features, - out_features, - bias=bias, - dtype=module_weight.dtype, - ) - - tmp_state_dict = {"weight": current_param_weight} - if module_bias is not None: - tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]}) - original_module.load_state_dict(tmp_state_dict, assign=True, strict=True) - setattr(parent_module, current_module_name, original_module) - - del tmp_state_dict - - if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX: - attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name] - new_value = int(current_param_weight.shape[1]) - old_value = getattr(transformer.config, attribute_name) - setattr(transformer.config, attribute_name, new_value) - logger.info( - f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}." - ) - - @classmethod - def _maybe_expand_transformer_param_shape_or_error_( - cls, - transformer: torch.nn.Module, - lora_state_dict=None, - norm_state_dict=None, - prefix=None, - ) -> bool: - """ - Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and - generalizes things a bit so that any parameter that needs expansion receives appropriate treatment. - """ - state_dict = {} - if lora_state_dict is not None: - state_dict.update(lora_state_dict) - if norm_state_dict is not None: - state_dict.update(norm_state_dict) - - # Remove prefix if present - prefix = prefix or cls.transformer_name - for key in list(state_dict.keys()): - if key.split(".")[0] == prefix: - state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key) - - # Expand transformer parameter shapes if they don't match lora - has_param_with_shape_update = False - overwritten_params = {} - - is_peft_loaded = getattr(transformer, "peft_config", None) is not None - is_quantized = hasattr(transformer, "hf_quantizer") - for name, module in transformer.named_modules(): - if isinstance(module, torch.nn.Linear): - module_weight = module.weight.data - module_bias = module.bias.data if module.bias is not None else None - bias = module_bias is not None - - lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name - lora_A_weight_name = f"{lora_base_name}.lora_A.weight" - lora_B_weight_name = f"{lora_base_name}.lora_B.weight" - if lora_A_weight_name not in state_dict: - continue - - in_features = state_dict[lora_A_weight_name].shape[1] - out_features = state_dict[lora_B_weight_name].shape[0] - - # Model maybe loaded with different quantization schemes which may flatten the params. - # `bitsandbytes`, for example, flatten the weights when using 4bit. 8bit bnb models - # preserve weight shape. - module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module) - - # This means there's no need for an expansion in the params, so we simply skip. - if tuple(module_weight_shape) == (out_features, in_features): - continue - - module_out_features, module_in_features = module_weight_shape - debug_message = "" - if in_features > module_in_features: - debug_message += ( - f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA ' - f"checkpoint contains higher number of features than expected. The number of input_features will be " - f"expanded from {module_in_features} to {in_features}" - ) - if out_features > module_out_features: - debug_message += ( - ", and the number of output features will be " - f"expanded from {module_out_features} to {out_features}." - ) - else: - debug_message += "." - if debug_message: - logger.debug(debug_message) - - if out_features > module_out_features or in_features > module_in_features: - has_param_with_shape_update = True - parent_module_name, _, current_module_name = name.rpartition(".") - parent_module = transformer.get_submodule(parent_module_name) - - if is_quantized: - module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module) - - # TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True. - with torch.device("meta"): - expanded_module = torch.nn.Linear( - in_features, out_features, bias=bias, dtype=module_weight.dtype - ) - # Only weights are expanded and biases are not. This is because only the input dimensions - # are changed while the output dimensions remain the same. The shape of the weight tensor - # is (out_features, in_features), while the shape of bias tensor is (out_features,), which - # explains the reason why only weights are expanded. - new_weight = torch.zeros_like( - expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype - ) - slices = tuple(slice(0, dim) for dim in module_weight_shape) - new_weight[slices] = module_weight - tmp_state_dict = {"weight": new_weight} - if module_bias is not None: - tmp_state_dict["bias"] = module_bias - expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True) - - setattr(parent_module, current_module_name, expanded_module) - - del tmp_state_dict - - if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX: - attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name] - new_value = int(expanded_module.weight.data.shape[1]) - old_value = getattr(transformer.config, attribute_name) - setattr(transformer.config, attribute_name, new_value) - logger.info( - f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}." - ) - - # For `unload_lora_weights()`. - # TODO: this could lead to more memory overhead if the number of overwritten params - # are large. Should be revisited later and tackled through a `discard_original_layers` arg. - overwritten_params[f"{current_module_name}.weight"] = module_weight - if module_bias is not None: - overwritten_params[f"{current_module_name}.bias"] = module_bias - - if len(overwritten_params) > 0: - transformer._overwritten_params = overwritten_params - - return has_param_with_shape_update - - @classmethod - def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): - expanded_module_names = set() - transformer_state_dict = transformer.state_dict() - prefix = f"{cls.transformer_name}." - - lora_module_names = [ - key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight") - ] - lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)] - lora_module_names = sorted(set(lora_module_names)) - transformer_module_names = sorted({name for name, _ in transformer.named_modules()}) - unexpected_modules = set(lora_module_names) - set(transformer_module_names) - if unexpected_modules: - logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.") - - for k in lora_module_names: - if k in unexpected_modules: - continue - - base_param_name = ( - f"{k.replace(prefix, '')}.base_layer.weight" - if f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict - else f"{k.replace(prefix, '')}.weight" - ) - base_weight_param = transformer_state_dict[base_param_name] - lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"] - - # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization. - base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name) - - if base_module_shape[1] > lora_A_param.shape[1]: - shape = (lora_A_param.shape[0], base_weight_param.shape[1]) - expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device) - expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param) - lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight - expanded_module_names.add(k) - elif base_module_shape[1] < lora_A_param.shape[1]: - raise NotImplementedError( - f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new." - ) - - if expanded_module_names: - logger.info( - f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new." - ) - - return lora_state_dict - - @staticmethod - def _calculate_module_shape( - model: "torch.nn.Module", - base_module: "torch.nn.Linear" = None, - base_weight_param_name: str = None, - ) -> "torch.Size": - def _get_weight_shape(weight: torch.Tensor): - if weight.__class__.__name__ == "Params4bit": - return weight.quant_state.shape - elif weight.__class__.__name__ == "GGUFParameter": - return weight.quant_shape - else: - return weight.shape - - if base_module is not None: - return _get_weight_shape(base_module.weight) - elif base_weight_param_name is not None: - if not base_weight_param_name.endswith(".weight"): - raise ValueError( - f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}." - ) - module_path = base_weight_param_name.rsplit(".weight", 1)[0] - submodule = get_submodule_by_name(model, module_path) - return _get_weight_shape(submodule.weight) - - raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.") - - @staticmethod - def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, return_metadata=False): - outputs = [state_dict] - if return_alphas: - outputs.append(alphas) - if return_metadata: - outputs.append(metadata) - return tuple(outputs) if (return_alphas or return_metadata) else state_dict + return FluxLoraLoaderMixin(*args, **kwargs) # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially diff --git a/src/diffusers/pipelines/flux/lora_utils.py b/src/diffusers/pipelines/flux/lora_utils.py new file mode 100644 index 0000000000..adf7a68cbb --- /dev/null +++ b/src/diffusers/pipelines/flux/lora_utils.py @@ -0,0 +1,839 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Callable, Dict, List, Optional, Union + +import torch + +from ...loaders.lora_base import ( + _LOW_CPU_MEM_USAGE_DEFAULT_LORA, + TEXT_ENCODER_NAME, + TRANSFORMER_NAME, + LoraBaseMixin, + _fetch_state_dict, + _load_lora_into_text_encoder, + _maybe_dequantize_weight_for_expanded_lora, +) +from ...loaders.lora_conversion_utils import ( + _convert_bfl_flux_control_lora_to_diffusers, + _convert_fal_kontext_lora_to_diffusers, + _convert_kohya_flux_lora_to_diffusers, + _convert_xlabs_flux_lora_to_diffusers, +) +from ...utils import USE_PEFT_BACKEND, get_submodule_by_name, is_peft_version, logging +from ...utils.hub_utils import validate_hf_hub_args + + +logger = logging.get_logger(__name__) + +_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"} + + +class FluxLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`FluxTransformer2DModel`], + [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). + + Specific to [`StableDiffusion3Pipeline`]. + """ + + _lora_loadable_modules = ["transformer", "text_encoder"] + transformer_name = TRANSFORMER_NAME + text_encoder_name = TEXT_ENCODER_NAME + _control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"] + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + return_alphas: bool = False, + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details. + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + + state_dict, metadata = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions. + is_kohya = any(".lora_down.weight" in k for k in state_dict) + if is_kohya: + state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict) + # Kohya already takes care of scaling the LoRA parameters with alpha. + return cls._prepare_outputs( + state_dict, + metadata=metadata, + alphas=None, + return_alphas=return_alphas, + return_metadata=return_lora_metadata, + ) + + is_xlabs = any("processor" in k for k in state_dict) + if is_xlabs: + state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict) + # xlabs doesn't use `alpha`. + return cls._prepare_outputs( + state_dict, + metadata=metadata, + alphas=None, + return_alphas=return_alphas, + return_metadata=return_lora_metadata, + ) + + is_bfl_control = any("query_norm.scale" in k for k in state_dict) + if is_bfl_control: + state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict) + return cls._prepare_outputs( + state_dict, + metadata=metadata, + alphas=None, + return_alphas=return_alphas, + return_metadata=return_lora_metadata, + ) + + is_fal_kontext = any("base_model" in k for k in state_dict) + if is_fal_kontext: + state_dict = _convert_fal_kontext_lora_to_diffusers(state_dict) + return cls._prepare_outputs( + state_dict, + metadata=metadata, + alphas=None, + return_alphas=return_alphas, + return_metadata=return_lora_metadata, + ) + + # For state dicts like + # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA + keys = list(state_dict.keys()) + network_alphas = {} + for k in keys: + if "alpha" in k: + alpha_value = state_dict.get(k) + if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance( + alpha_value, float + ): + network_alphas[k] = state_dict.pop(k) + else: + raise ValueError( + f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue." + ) + + if return_alphas or return_lora_metadata: + return cls._prepare_outputs( + state_dict, + metadata=metadata, + alphas=network_alphas, + return_alphas=return_alphas, + return_metadata=return_lora_metadata, + ) + else: + return state_dict + + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. + + All kwargs are forwarded to `self.lora_state_dict`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is + loaded. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + `Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + kwargs["return_lora_metadata"] = True + state_dict, network_alphas, metadata = self.lora_state_dict( + pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs + ) + + has_lora_keys = any("lora" in key for key in state_dict.keys()) + + # Flux Control LoRAs also have norm keys + has_norm_keys = any( + norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys + ) + + if not (has_lora_keys or has_norm_keys): + raise ValueError("Invalid LoRA checkpoint.") + + transformer_lora_state_dict = { + k: state_dict.get(k) + for k in list(state_dict.keys()) + if k.startswith(f"{self.transformer_name}.") and "lora" in k + } + transformer_norm_state_dict = { + k: state_dict.pop(k) + for k in list(state_dict.keys()) + if k.startswith(f"{self.transformer_name}.") + and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys) + } + + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + has_param_with_expanded_shape = False + if len(transformer_lora_state_dict) > 0: + has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_( + transformer, transformer_lora_state_dict, transformer_norm_state_dict + ) + + if has_param_with_expanded_shape: + logger.info( + "The LoRA weights contain parameters that have different shapes that expected by the transformer. " + "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. " + "To get a comprehensive list of parameter names that were modified, enable debug logging." + ) + if len(transformer_lora_state_dict) > 0: + transformer_lora_state_dict = self._maybe_expand_lora_state_dict( + transformer=transformer, lora_state_dict=transformer_lora_state_dict + ) + for k in transformer_lora_state_dict: + state_dict.update({k: transformer_lora_state_dict[k]}) + + self.load_lora_into_transformer( + state_dict, + network_alphas=network_alphas, + transformer=transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + if len(transformer_norm_state_dict) > 0: + transformer._transformer_norm_layers = self._load_norm_into_transformer( + transformer_norm_state_dict, + transformer=transformer, + discard_original_layers=False, + ) + + self.load_lora_into_text_encoder( + state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder, + prefix=self.text_encoder_name, + lora_scale=self.lora_scale, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + def load_lora_into_transformer( + cls, + state_dict, + network_alphas, + transformer, + adapter_name=None, + metadata=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + ): + """ + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details. + """ + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + def _load_norm_into_transformer( + cls, + state_dict, + transformer, + prefix=None, + discard_original_layers=False, + ) -> Dict[str, torch.Tensor]: + # Remove prefix if present + prefix = prefix or cls.transformer_name + for key in list(state_dict.keys()): + if key.split(".")[0] == prefix: + state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key) + + # Find invalid keys + transformer_state_dict = transformer.state_dict() + transformer_keys = set(transformer_state_dict.keys()) + state_dict_keys = set(state_dict.keys()) + extra_keys = list(state_dict_keys - transformer_keys) + + if extra_keys: + logger.warning( + f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}." + ) + + for key in extra_keys: + state_dict.pop(key) + + # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected + overwritten_layers_state_dict = {} + if not discard_original_layers: + for key in state_dict.keys(): + overwritten_layers_state_dict[key] = transformer_state_dict[key].clone() + + logger.info( + "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer " + 'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly ' + "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed. This might also have implications when dealing with multiple LoRAs. " + "If you notice something unexpected, please open an issue: https://github.com/huggingface/diffusers/issues." + ) + + # We can't load with strict=True because the current state_dict does not contain all the transformer keys + incompatible_keys = transformer.load_state_dict(state_dict, strict=False) + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + + # We shouldn't expect to see the supported norm keys here being present in the unexpected keys. + if unexpected_keys: + if any(norm_key in k for k in unexpected_keys for norm_key in cls._control_lora_supported_norm_keys): + raise ValueError( + f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer." + ) + + return overwritten_layers_state_dict + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder + def load_lora_into_text_encoder( + cls, + state_dict, + network_alphas, + text_encoder, + prefix=None, + lora_scale=1.0, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, + ): + """ + This will load the LoRA layers specified in `state_dict` into `text_encoder` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The key should be prefixed with an + additional `text_encoder` to distinguish between unet lora layers. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + text_encoder (`CLIPTextModel`): + The text encoder model to load the LoRA layers into. + prefix (`str`): + Expected prefix of the `text_encoder` in the `state_dict`. + lora_scale (`float`): + How much to scale the output of the lora linear layer before it is added with the output of the regular + lora layer. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. + """ + _load_lora_into_text_encoder( + state_dict=state_dict, + network_alphas=network_alphas, + lora_scale=lora_scale, + text_encoder=text_encoder, + prefix=prefix, + text_encoder_name=cls.text_encoder_name, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + transformer_lora_adapter_metadata=None, + text_encoder_lora_adapter_metadata=None, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. + text_encoder_lora_adapter_metadata: + LoRA adapter metadata associated with the text encoder to be serialized with the state dict. + """ + lora_layers = {} + lora_metadata = {} + + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata + + if text_encoder_lora_layers: + lora_layers[cls.text_encoder_name] = text_encoder_lora_layers + lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata + + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") + + cls._save_lora_weights( + save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details. + """ + + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + if ( + hasattr(transformer, "_transformer_norm_layers") + and isinstance(transformer._transformer_norm_layers, dict) + and len(transformer._transformer_norm_layers.keys()) > 0 + ): + logger.info( + "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will be directly updated the state_dict of the transformer " + "as opposed to the LoRA layers that will co-exist separately until the 'fuse_lora()' method is called. That is to say, the normalization layers will always be directly " + "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed." + ) + + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + > [!WARNING] > This is an experimental API. + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + """ + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers: + transformer.load_state_dict(transformer._transformer_norm_layers, strict=False) + + super().unfuse_lora(components=components, **kwargs) + + # We override this here account for `_transformer_norm_layers` and `_overwritten_params`. + def unload_lora_weights(self, reset_to_overwritten_params=False): + """ + Unloads the LoRA parameters. + + Args: + reset_to_overwritten_params (`bool`, defaults to `False`): Whether to reset the LoRA-loaded modules + to their original params. Refer to the [Flux + documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) to learn more. + + Examples: + + ```python + >>> # Assuming `pipeline` is already loaded with the LoRA parameters. + >>> pipeline.unload_lora_weights() + >>> ... + ``` + """ + super().unload_lora_weights() + + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers: + transformer.load_state_dict(transformer._transformer_norm_layers, strict=False) + transformer._transformer_norm_layers = None + + if reset_to_overwritten_params and getattr(transformer, "_overwritten_params", None) is not None: + overwritten_params = transformer._overwritten_params + module_names = set() + + for param_name in overwritten_params: + if param_name.endswith(".weight"): + module_names.add(param_name.replace(".weight", "")) + + for name, module in transformer.named_modules(): + if isinstance(module, torch.nn.Linear) and name in module_names: + module_weight = module.weight.data + module_bias = module.bias.data if module.bias is not None else None + bias = module_bias is not None + + parent_module_name, _, current_module_name = name.rpartition(".") + parent_module = transformer.get_submodule(parent_module_name) + + current_param_weight = overwritten_params[f"{name}.weight"] + in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0] + with torch.device("meta"): + original_module = torch.nn.Linear( + in_features, + out_features, + bias=bias, + dtype=module_weight.dtype, + ) + + tmp_state_dict = {"weight": current_param_weight} + if module_bias is not None: + tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]}) + original_module.load_state_dict(tmp_state_dict, assign=True, strict=True) + setattr(parent_module, current_module_name, original_module) + + del tmp_state_dict + + if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX: + attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name] + new_value = int(current_param_weight.shape[1]) + old_value = getattr(transformer.config, attribute_name) + setattr(transformer.config, attribute_name, new_value) + logger.info( + f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}." + ) + + @classmethod + def _maybe_expand_transformer_param_shape_or_error_( + cls, + transformer: torch.nn.Module, + lora_state_dict=None, + norm_state_dict=None, + prefix=None, + ) -> bool: + """ + Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and + generalizes things a bit so that any parameter that needs expansion receives appropriate treatment. + """ + state_dict = {} + if lora_state_dict is not None: + state_dict.update(lora_state_dict) + if norm_state_dict is not None: + state_dict.update(norm_state_dict) + + # Remove prefix if present + prefix = prefix or cls.transformer_name + for key in list(state_dict.keys()): + if key.split(".")[0] == prefix: + state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key) + + # Expand transformer parameter shapes if they don't match lora + has_param_with_shape_update = False + overwritten_params = {} + + is_peft_loaded = getattr(transformer, "peft_config", None) is not None + is_quantized = hasattr(transformer, "hf_quantizer") + for name, module in transformer.named_modules(): + if isinstance(module, torch.nn.Linear): + module_weight = module.weight.data + module_bias = module.bias.data if module.bias is not None else None + bias = module_bias is not None + + lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name + lora_A_weight_name = f"{lora_base_name}.lora_A.weight" + lora_B_weight_name = f"{lora_base_name}.lora_B.weight" + if lora_A_weight_name not in state_dict: + continue + + in_features = state_dict[lora_A_weight_name].shape[1] + out_features = state_dict[lora_B_weight_name].shape[0] + + # Model maybe loaded with different quantization schemes which may flatten the params. + # `bitsandbytes`, for example, flatten the weights when using 4bit. 8bit bnb models + # preserve weight shape. + module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module) + + # This means there's no need for an expansion in the params, so we simply skip. + if tuple(module_weight_shape) == (out_features, in_features): + continue + + module_out_features, module_in_features = module_weight_shape + debug_message = "" + if in_features > module_in_features: + debug_message += ( + f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA ' + f"checkpoint contains higher number of features than expected. The number of input_features will be " + f"expanded from {module_in_features} to {in_features}" + ) + if out_features > module_out_features: + debug_message += ( + ", and the number of output features will be " + f"expanded from {module_out_features} to {out_features}." + ) + else: + debug_message += "." + if debug_message: + logger.debug(debug_message) + + if out_features > module_out_features or in_features > module_in_features: + has_param_with_shape_update = True + parent_module_name, _, current_module_name = name.rpartition(".") + parent_module = transformer.get_submodule(parent_module_name) + + if is_quantized: + module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module) + + # TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True. + with torch.device("meta"): + expanded_module = torch.nn.Linear( + in_features, out_features, bias=bias, dtype=module_weight.dtype + ) + # Only weights are expanded and biases are not. This is because only the input dimensions + # are changed while the output dimensions remain the same. The shape of the weight tensor + # is (out_features, in_features), while the shape of bias tensor is (out_features,), which + # explains the reason why only weights are expanded. + new_weight = torch.zeros_like( + expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype + ) + slices = tuple(slice(0, dim) for dim in module_weight_shape) + new_weight[slices] = module_weight + tmp_state_dict = {"weight": new_weight} + if module_bias is not None: + tmp_state_dict["bias"] = module_bias + expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True) + + setattr(parent_module, current_module_name, expanded_module) + + del tmp_state_dict + + if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX: + attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name] + new_value = int(expanded_module.weight.data.shape[1]) + old_value = getattr(transformer.config, attribute_name) + setattr(transformer.config, attribute_name, new_value) + logger.info( + f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}." + ) + + # For `unload_lora_weights()`. + # TODO: this could lead to more memory overhead if the number of overwritten params + # are large. Should be revisited later and tackled through a `discard_original_layers` arg. + overwritten_params[f"{current_module_name}.weight"] = module_weight + if module_bias is not None: + overwritten_params[f"{current_module_name}.bias"] = module_bias + + if len(overwritten_params) > 0: + transformer._overwritten_params = overwritten_params + + return has_param_with_shape_update + + @classmethod + def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): + expanded_module_names = set() + transformer_state_dict = transformer.state_dict() + prefix = f"{cls.transformer_name}." + + lora_module_names = [ + key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight") + ] + lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)] + lora_module_names = sorted(set(lora_module_names)) + transformer_module_names = sorted({name for name, _ in transformer.named_modules()}) + unexpected_modules = set(lora_module_names) - set(transformer_module_names) + if unexpected_modules: + logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.") + + for k in lora_module_names: + if k in unexpected_modules: + continue + + base_param_name = ( + f"{k.replace(prefix, '')}.base_layer.weight" + if f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict + else f"{k.replace(prefix, '')}.weight" + ) + base_weight_param = transformer_state_dict[base_param_name] + lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"] + + # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization. + base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name) + + if base_module_shape[1] > lora_A_param.shape[1]: + shape = (lora_A_param.shape[0], base_weight_param.shape[1]) + expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device) + expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param) + lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight + expanded_module_names.add(k) + elif base_module_shape[1] < lora_A_param.shape[1]: + raise NotImplementedError( + f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new." + ) + + if expanded_module_names: + logger.info( + f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new." + ) + + return lora_state_dict + + @staticmethod + def _calculate_module_shape( + model: "torch.nn.Module", + base_module: "torch.nn.Linear" = None, + base_weight_param_name: str = None, + ) -> "torch.Size": + def _get_weight_shape(weight: torch.Tensor): + if weight.__class__.__name__ == "Params4bit": + return weight.quant_state.shape + elif weight.__class__.__name__ == "GGUFParameter": + return weight.quant_shape + else: + return weight.shape + + if base_module is not None: + return _get_weight_shape(base_module.weight) + elif base_weight_param_name is not None: + if not base_weight_param_name.endswith(".weight"): + raise ValueError( + f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}." + ) + module_path = base_weight_param_name.rsplit(".weight", 1)[0] + submodule = get_submodule_by_name(model, module_path) + return _get_weight_shape(submodule.weight) + + raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.") + + @staticmethod + def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, return_metadata=False): + outputs = [state_dict] + if return_alphas: + outputs.append(alphas) + if return_metadata: + outputs.append(metadata) + return tuple(outputs) if (return_alphas or return_metadata) else state_dict diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 8d03ed9b27..6e4f0b9477 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -26,16 +26,13 @@ from transformers import ( ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...loaders import FluxIPAdapterMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import ( - is_torch_xla_available, - logging, - replace_example_docstring, -) +from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .lora_utils import FluxLoraLoaderMixin from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_timesteps from .pipeline_output import FluxPipelineOutput diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py index 6d04c21ee5..b9f349ff24 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py @@ -19,17 +19,14 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import ( - is_torch_xla_available, - logging, - replace_example_docstring, -) +from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .lora_utils import FluxLoraLoaderMixin from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_timesteps from .pipeline_output import FluxPipelineOutput diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py index 068ab7132f..1dc4252344 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py @@ -19,13 +19,14 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from ...loaders import FromSingleFileMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .lora_utils import FluxLoraLoaderMixin from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps from .pipeline_output import FluxPipelineOutput diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py index 1bf2f343c3..fbec774275 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py @@ -24,17 +24,14 @@ from transformers import ( ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import ( - FluxLoraLoaderMixin, - FromSingleFileMixin, - TextualInversionLoaderMixin, -) +from ...loaders import FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .lora_utils import FluxLoraLoaderMixin from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps from .pipeline_output import FluxPipelineOutput diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index b26f3dcd12..84679f8c40 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -26,7 +26,7 @@ from transformers import ( ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin +from ...loaders import FluxIPAdapterMixin, FromSingleFileMixin from ...models.autoencoders import AutoencoderKL from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.transformers import FluxTransformer2DModel @@ -34,6 +34,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .lora_utils import FluxLoraLoaderMixin from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps from .pipeline_output import FluxPipelineOutput diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index f246c8be10..7fbfb2d7d4 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -10,7 +10,7 @@ from transformers import ( ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from ...loaders import FromSingleFileMixin from ...models.autoencoders import AutoencoderKL from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.transformers import FluxTransformer2DModel @@ -18,6 +18,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .lora_utils import FluxLoraLoaderMixin from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps from .pipeline_output import FluxPipelineOutput diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index b59b312239..7eee351976 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -11,7 +11,7 @@ from transformers import ( ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from ...loaders import FromSingleFileMixin from ...models.autoencoders import AutoencoderKL from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.transformers import FluxTransformer2DModel @@ -19,6 +19,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .lora_utils import FluxLoraLoaderMixin from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps from .pipeline_output import FluxPipelineOutput diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index aadf8e6077..247fd8e083 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -19,13 +19,14 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from ...image_processor import VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .lora_utils import FluxLoraLoaderMixin from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps from .pipeline_output import FluxPipelineOutput diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index f8ba950883..283197a8b1 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -26,13 +26,14 @@ from transformers import ( ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin +from ...loaders import FluxIPAdapterMixin, FromSingleFileMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .lora_utils import FluxLoraLoaderMixin from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps from .pipeline_output import FluxPipelineOutput diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index cedbbd3953..c4bedfbb16 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -27,7 +27,7 @@ from transformers import ( ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin +from ...loaders import FluxIPAdapterMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -38,6 +38,7 @@ from ...utils import ( ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .lora_utils import FluxLoraLoaderMixin from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps from .pipeline_output import FluxPipelineOutput diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py index 29ddfa333d..7e7deb265f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py @@ -26,12 +26,13 @@ from transformers import ( ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...loaders import FluxIPAdapterMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .lora_utils import FluxLoraLoaderMixin from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps from .pipeline_output import FluxPipelineOutput diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py index b421afeb0d..19edb171eb 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py @@ -16,12 +16,13 @@ from transformers import ( ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...loaders import FluxIPAdapterMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .lora_utils import FluxLoraLoaderMixin from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps from .pipeline_output import FluxPipelineOutput