From d1db4f853a8d5da0a4bc4112010bca8d900871ef Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 16 Jun 2025 14:26:35 +0530 Subject: [PATCH 1/8] [LoRA ]fix flux lora loader when return_metadata is true for non-diffusers (#11716) * fix flux lora loader when return_metadata is true for non-diffusers * remove annotation --- src/diffusers/loaders/lora_pipeline.py | 46 ++++++++++++++++++++----- src/diffusers/loaders/peft.py | 4 ++- src/diffusers/utils/state_dict_utils.py | 7 ++-- 3 files changed, 45 insertions(+), 12 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 27053623ee..8fdd8a88ed 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2031,18 +2031,36 @@ class FluxLoraLoaderMixin(LoraBaseMixin): if is_kohya: state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict) # Kohya already takes care of scaling the LoRA parameters with alpha. - return (state_dict, None) if return_alphas else state_dict + return cls._prepare_outputs( + state_dict, + metadata=metadata, + alphas=None, + return_alphas=return_alphas, + return_metadata=return_lora_metadata, + ) is_xlabs = any("processor" in k for k in state_dict) if is_xlabs: state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict) # xlabs doesn't use `alpha`. - return (state_dict, None) if return_alphas else state_dict + return cls._prepare_outputs( + state_dict, + metadata=metadata, + alphas=None, + return_alphas=return_alphas, + return_metadata=return_lora_metadata, + ) is_bfl_control = any("query_norm.scale" in k for k in state_dict) if is_bfl_control: state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict) - return (state_dict, None) if return_alphas else state_dict + return cls._prepare_outputs( + state_dict, + metadata=metadata, + alphas=None, + return_alphas=return_alphas, + return_metadata=return_lora_metadata, + ) # For state dicts like # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA @@ -2061,12 +2079,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ) if return_alphas or return_lora_metadata: - outputs = [state_dict] - if return_alphas: - outputs.append(network_alphas) - if return_lora_metadata: - outputs.append(metadata) - return tuple(outputs) + return cls._prepare_outputs( + state_dict, + metadata=metadata, + alphas=network_alphas, + return_alphas=return_alphas, + return_metadata=return_lora_metadata, + ) else: return state_dict @@ -2785,6 +2804,15 @@ class FluxLoraLoaderMixin(LoraBaseMixin): raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.") + @staticmethod + def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, return_metadata=False): + outputs = [state_dict] + if return_alphas: + outputs.append(alphas) + if return_metadata: + outputs.append(metadata) + return tuple(outputs) if (return_alphas or return_metadata) else state_dict + # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support. diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index e7a458f28e..6bb6e36936 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -187,7 +187,9 @@ class PeftAdapterMixin: Note that hotswapping adapters of the text encoder is not yet supported. There are some further limitations to this technique, which are documented here: https://huggingface.co/docs/peft/main/en/package_reference/hotswap - metadata: TODO + metadata: + LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to + initialize `LoraConfig`. """ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict from peft.tuners.tuners_utils import BaseTunerLayer diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 498f7e566c..8e6078488a 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -359,5 +359,8 @@ def _load_sft_state_dict_metadata(model_file: str): metadata = f.metadata() or {} metadata.pop("format", None) - raw = metadata.get(LORA_ADAPTER_METADATA_KEY) - return json.loads(raw) if raw else None + if metadata: + raw = metadata.get(LORA_ADAPTER_METADATA_KEY) + return json.loads(raw) if raw else None + else: + return None From f0dba33d82af991369806312f61ab4c6cb7a8dd1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 16 Jun 2025 16:42:34 +0530 Subject: [PATCH 2/8] [training] show how metadata stuff should be incorporated in training scripts. (#11707) * show how metadata stuff should be incorporated in training scripts. * typing * fix --------- Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> --- .../dreambooth/test_dreambooth_lora_flux.py | 45 +++++++++++++++++++ .../dreambooth/train_dreambooth_lora_flux.py | 22 ++++++--- src/diffusers/training_utils.py | 8 ++++ 3 files changed, 70 insertions(+), 5 deletions(-) diff --git a/examples/dreambooth/test_dreambooth_lora_flux.py b/examples/dreambooth/test_dreambooth_lora_flux.py index a76825e294..837a537b5a 100644 --- a/examples/dreambooth/test_dreambooth_lora_flux.py +++ b/examples/dreambooth/test_dreambooth_lora_flux.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging import os import sys @@ -20,6 +21,8 @@ import tempfile import safetensors +from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY + sys.path.append("..") from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 @@ -234,3 +237,45 @@ class DreamBoothLoRAFlux(ExamplesTestsAccelerate): run_command(self._launch_args + resume_run_args) self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"}) + + def test_dreambooth_lora_with_metadata(self): + # Use a `lora_alpha` that is different from `rank`. + lora_alpha = 8 + rank = 4 + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --lora_alpha={lora_alpha} + --rank={rank} + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors") + self.assertTrue(os.path.isfile(state_dict_file)) + + # Check if the metadata was properly serialized. + with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f: + metadata = f.metadata() or {} + + metadata.pop("format", None) + raw = metadata.get(LORA_ADAPTER_METADATA_KEY) + if raw: + raw = json.loads(raw) + + loaded_lora_alpha = raw["transformer.lora_alpha"] + self.assertTrue(loaded_lora_alpha == lora_alpha) + loaded_lora_rank = raw["transformer.r"] + self.assertTrue(loaded_lora_rank == rank) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 1caf9c62d7..9c529cbb92 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -27,7 +27,6 @@ from pathlib import Path import numpy as np import torch -import torch.utils.checkpoint import transformers from accelerate import Accelerator from accelerate.logging import get_logger @@ -53,6 +52,7 @@ from diffusers import ( ) from diffusers.optimization import get_scheduler from diffusers.training_utils import ( + _collate_lora_metadata, _set_state_dict_into_text_encoder, cast_training_params, compute_density_for_timestep_sampling, @@ -358,7 +358,12 @@ def parse_args(input_args=None): default=4, help=("The dimension of the LoRA update matrices."), ) - + parser.add_argument( + "--lora_alpha", + type=int, + default=4, + help="LoRA alpha to be used for additional scaling.", + ) parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers") parser.add_argument( @@ -1238,7 +1243,7 @@ def main(args): # now we will add new LoRA weights the transformer layers transformer_lora_config = LoraConfig( r=args.rank, - lora_alpha=args.rank, + lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout, init_lora_weights="gaussian", target_modules=target_modules, @@ -1247,7 +1252,7 @@ def main(args): if args.train_text_encoder: text_lora_config = LoraConfig( r=args.rank, - lora_alpha=args.rank, + lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], @@ -1264,12 +1269,14 @@ def main(args): if accelerator.is_main_process: transformer_lora_layers_to_save = None text_encoder_one_lora_layers_to_save = None - + modules_to_save = {} for model in models: if isinstance(model, type(unwrap_model(transformer))): transformer_lora_layers_to_save = get_peft_model_state_dict(model) + modules_to_save["transformer"] = model elif isinstance(model, type(unwrap_model(text_encoder_one))): text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) + modules_to_save["text_encoder"] = model else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1280,6 +1287,7 @@ def main(args): output_dir, transformer_lora_layers=transformer_lora_layers_to_save, text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, + **_collate_lora_metadata(modules_to_save), ) def load_model_hook(models, input_dir): @@ -1889,16 +1897,19 @@ def main(args): # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: + modules_to_save = {} transformer = unwrap_model(transformer) if args.upcast_before_saving: transformer.to(torch.float32) else: transformer = transformer.to(weight_dtype) transformer_lora_layers = get_peft_model_state_dict(transformer) + modules_to_save["transformer"] = transformer if args.train_text_encoder: text_encoder_one = unwrap_model(text_encoder_one) text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32)) + modules_to_save["text_encoder"] = text_encoder_one else: text_encoder_lora_layers = None @@ -1906,6 +1917,7 @@ def main(args): save_directory=args.output_dir, transformer_lora_layers=transformer_lora_layers, text_encoder_lora_layers=text_encoder_lora_layers, + **_collate_lora_metadata(modules_to_save), ) # Final inference diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 43bf0010d7..bc30411d87 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -247,6 +247,14 @@ def _set_state_dict_into_text_encoder( set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default") +def _collate_lora_metadata(modules_to_save: Dict[str, torch.nn.Module]) -> Dict[str, Any]: + metadatas = {} + for module_name, module in modules_to_save.items(): + if module is not None: + metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict() + return metadatas + + def compute_density_for_timestep_sampling( weighting_scheme: str, batch_size: int, From 81426b0f19b529b3f002227c9c7f4fa883584a03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carl=20Thom=C3=A9?= Date: Mon, 16 Jun 2025 20:47:00 +0200 Subject: [PATCH 3/8] Fix misleading comment (#11722) --- scripts/convert_stable_audio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/convert_stable_audio.py b/scripts/convert_stable_audio.py index b33c8b0608..757d47a316 100644 --- a/scripts/convert_stable_audio.py +++ b/scripts/convert_stable_audio.py @@ -1,4 +1,4 @@ -# Run this script to convert the Stable Cascade model weights to a diffusers pipeline. +# Run this script to convert the Stable Audio model weights to a diffusers pipeline. import argparse import json import os From 9b834f871029f83391f9ae1c262f4b6882c594a5 Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Mon, 16 Jun 2025 21:25:05 +0200 Subject: [PATCH 4/8] Add Pruna optimization framework documentation (#11688) * Add Pruna optimization framework documentation - Introduced a new section for Pruna in the table of contents. - Added comprehensive documentation for Pruna, detailing its optimization techniques, installation instructions, and examples for optimizing and evaluating models * Enhance Pruna documentation with image alt text and code block formatting - Added alt text to images for better accessibility and context. - Changed code block syntax from diff to python for improved clarity. * Add installation section to Pruna documentation - Introduced a new installation section in the Pruna documentation to guide users on how to install the framework. - Enhanced the overall clarity and usability of the documentation for new users. * Update pruna.md * Update pruna.md * Update Pruna documentation for model optimization and evaluation - Changed section titles for consistency and clarity, from "Optimizing models" to "Optimize models" and "Evaluating and benchmarking optimized models" to "Evaluate and benchmark models". - Enhanced descriptions to clarify the use of `diffusers` models and the evaluation process. - Added a new example for evaluating standalone `diffusers` models. - Updated references and links for better navigation within the documentation. * Refactor Pruna documentation for clarity and consistency - Removed outdated references to FLUX-juiced and streamlined the explanation of benchmarking. - Enhanced the description of evaluating standalone `diffusers` models. - Cleaned up code examples by removing unnecessary imports and comments for better readability. * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Enhance Pruna documentation with new examples and clarifications - Added an image to illustrate the optimization process. - Updated the explanation for sharing and loading optimized models on the Hugging Face Hub. - Clarified the evaluation process for optimized models using the EvaluationAgent. - Improved descriptions for defining metrics and evaluating standalone diffusers models. --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/_toctree.yml | 2 + docs/source/en/optimization/pruna.md | 187 +++++++++++++++++++++++++++ 2 files changed, 189 insertions(+) create mode 100644 docs/source/en/optimization/pruna.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 5492dff04c..0530e11ac2 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -180,6 +180,8 @@ title: Caching - local: optimization/memory title: Reduce memory usage + - local: optimization/pruna + title: Pruna - local: optimization/xformers title: xFormers - local: optimization/tome diff --git a/docs/source/en/optimization/pruna.md b/docs/source/en/optimization/pruna.md new file mode 100644 index 0000000000..56c1f3af59 --- /dev/null +++ b/docs/source/en/optimization/pruna.md @@ -0,0 +1,187 @@ +# Pruna + +[Pruna](https://github.com/PrunaAI/pruna) is a model optimization framework that offers various optimization methods - quantization, pruning, caching, compilation - for accelerating inference and reducing memory usage. A general overview of the optimization methods are shown below. + + +| Technique | Description | Speed | Memory | Quality | +|--------------|-----------------------------------------------------------------------------------------------|:-----:|:------:|:-------:| +| `batcher` | Groups multiple inputs together to be processed simultaneously, improving computational efficiency and reducing processing time. | ✅ | ❌ | ➖ | +| `cacher` | Stores intermediate results of computations to speed up subsequent operations. | ✅ | ➖ | ➖ | +| `compiler` | Optimises the model with instructions for specific hardware. | ✅ | ➖ | ➖ | +| `distiller` | Trains a smaller, simpler model to mimic a larger, more complex model. | ✅ | ✅ | ❌ | +| `quantizer` | Reduces the precision of weights and activations, lowering memory requirements. | ✅ | ✅ | ❌ | +| `pruner` | Removes less important or redundant connections and neurons, resulting in a sparser, more efficient network. | ✅ | ✅ | ❌ | +| `recoverer` | Restores the performance of a model after compression. | ➖ | ➖ | ✅ | +| `factorizer` | Factorization batches several small matrix multiplications into one large fused operation. | ✅ | ➖ | ➖ | +| `enhancer` | Enhances the model output by applying post-processing algorithms such as denoising or upscaling. | ❌ | - | ✅ | + +✅ (improves), ➖ (approx. the same), ❌ (worsens) + +Explore the full range of optimization methods in the [Pruna documentation](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html#configure-algorithms). + +## Installation + +Install Pruna with the following command. + +```bash +pip install pruna +``` + + +## Optimize Diffusers models + +A broad range of optimization algorithms are supported for Diffusers models as shown below. + +
+ Overview of the supported optimization algorithms for diffusers models +
+ +The example below optimizes [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) +with a combination of factorizer, compiler, and cacher algorithms. This combination accelerates inference by up to 4.2x and cuts peak GPU memory usage from 34.7GB to 28.0GB, all while maintaining virtually the same output quality. + +> [!TIP] +> Refer to the [Pruna optimization](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html) docs to learn more about the optimization techniques used in this example. + +
+ Optimization techniques used for FLUX.1-dev showing the combination of factorizer, compiler, and cacher algorithms +
+ +Start by defining a `SmashConfig` with the optimization algorithms to use. To optimize the model, wrap the pipeline and the `SmashConfig` with `smash` and then use the pipeline as normal for inference. + +```python +import torch +from diffusers import FluxPipeline + +from pruna import PrunaModel, SmashConfig, smash + +# load the model +# Try segmind/Segmind-Vega or black-forest-labs/FLUX.1-schnell with a small GPU memory +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.bfloat16 +).to("cuda") + +# define the configuration +smash_config = SmashConfig() +smash_config["factorizer"] = "qkv_diffusers" +smash_config["compiler"] = "torch_compile" +smash_config["torch_compile_target"] = "module_list" +smash_config["cacher"] = "fora" +smash_config["fora_interval"] = 2 + +# for the best results in terms of speed you can add these configs +# however they will increase your warmup time from 1.5 min to 10 min +# smash_config["torch_compile_mode"] = "max-autotune-no-cudagraphs" +# smash_config["quantizer"] = "torchao" +# smash_config["torchao_quant_type"] = "fp8dq" +# smash_config["torchao_excluded_modules"] = "norm+embedding" + +# optimize the model +smashed_pipe = smash(pipe, smash_config) + +# run the model +smashed_pipe("a knitted purple prune").images[0] +``` + +
+ +
+ +After optimization, we can share and load the optimized model using the Hugging Face Hub. + +```python +# save the model +smashed_pipe.save_to_hub("/FLUX.1-dev-smashed") + +# load the model +smashed_pipe = PrunaModel.from_hub("/FLUX.1-dev-smashed") +``` + +## Evaluate and benchmark Diffusers models + +Pruna provides the [EvaluationAgent](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/evaluate.html) to evaluate the quality of your optimized models. + +We can metrics we care about, such as total time and throughput, and the dataset to evaluate on. We can define a model and pass it to the `EvaluationAgent`. + + + + +We can load and evaluate an optimized model by using the `EvaluationAgent` and pass it to the `Task`. + +```python +import torch +from diffusers import FluxPipeline + +from pruna import PrunaModel +from pruna.data.pruna_datamodule import PrunaDataModule +from pruna.evaluation.evaluation_agent import EvaluationAgent +from pruna.evaluation.metrics import ( + ThroughputMetric, + TorchMetricWrapper, + TotalTimeMetric, +) +from pruna.evaluation.task import Task + +# define the device +device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" + +# load the model +# Try PrunaAI/Segmind-Vega-smashed or PrunaAI/FLUX.1-dev-smashed with a small GPU memory +smashed_pipe = PrunaModel.from_hub("PrunaAI/FLUX.1-dev-smashed") + +# Define the metrics +metrics = [ + TotalTimeMetric(n_iterations=20, n_warmup_iterations=5), + ThroughputMetric(n_iterations=20, n_warmup_iterations=5), + TorchMetricWrapper("clip"), +] + +# Define the datamodule +datamodule = PrunaDataModule.from_string("LAION256") +datamodule.limit_datasets(10) + +# Define the task and evaluation agent +task = Task(metrics, datamodule=datamodule, device=device) +eval_agent = EvaluationAgent(task) + +# Evaluate smashed model and offload it to CPU +smashed_pipe.move_to_device(device) +smashed_pipe_results = eval_agent.evaluate(smashed_pipe) +smashed_pipe.move_to_device("cpu") +``` + + + + +Instead of comparing the optimized model to the base model, you can also evaluate the standalone `diffusers` model. This is useful if you want to evaluate the performance of the model without the optimization. We can do so by using the `PrunaModel` wrapper and run the `EvaluationAgent` on it. + +```python +import torch +from diffusers import FluxPipeline + +from pruna import PrunaModel + +# load the model +# Try PrunaAI/Segmind-Vega-smashed or PrunaAI/FLUX.1-dev-smashed with a small GPU memory +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.bfloat16 +).to("cpu") +wrapped_pipe = PrunaModel(model=pipe) +``` + + + + +Now that you have seen how to optimize and evaluate your models, you can start using Pruna to optimize your own models. Luckily, we have many examples to help you get started. + +> [!TIP] +> For more details about benchmarking Flux, check out the [Announcing FLUX-Juiced: The Fastest Image Generation Endpoint (2.6 times faster)!](https://huggingface.co/blog/PrunaAI/flux-fastest-image-generation-endpoint) blog post and the [InferBench](https://huggingface.co/spaces/PrunaAI/InferBench) Space. + +## Reference + +- [Pruna](https://github.com/pruna-ai/pruna) +- [Pruna optimization](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html#configure-algorithms) +- [Pruna evaluation](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/evaluate.html) +- [Pruna tutorials](https://docs.pruna.ai/en/stable/docs_pruna/tutorials/index.html) + From 79bd7ecc7807c31edd7eee0cb136fd065334931e Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 17 Jun 2025 10:39:18 +0530 Subject: [PATCH 5/8] Support more Wan loras (VACE) (#11726) update --- .../loaders/lora_conversion_utils.py | 87 +++++++++++-------- 1 file changed, 51 insertions(+), 36 deletions(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 7bde2a00be..d797222e83 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -1596,7 +1596,10 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): converted_state_dict = {} original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()} - num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict if "blocks." in k}) + block_numbers = {int(k.split(".")[1]) for k in original_state_dict if k.startswith("blocks.")} + min_block = min(block_numbers) + max_block = max(block_numbers) + is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict) lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down" lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up" @@ -1622,45 +1625,57 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): # For the `diff_b` keys, we treat them as lora_bias. # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias - for i in range(num_blocks): + for i in range(min_block, max_block + 1): # Self-attention for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]): - converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop( - f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight" - ) - converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop( - f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight" - ) - if f"blocks.{i}.self_attn.{o}.diff_b" in original_state_dict: - converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.bias"] = original_state_dict.pop( - f"blocks.{i}.self_attn.{o}.diff_b" - ) + original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight" + converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight" + converted_key = f"blocks.{i}.attn1.{c}.lora_B.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + original_key = f"blocks.{i}.self_attn.{o}.diff_b" + converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) # Cross-attention for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]): - converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight" - ) - converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight" - ) - if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict: - converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.diff_b" - ) + original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight" + converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight" + converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + original_key = f"blocks.{i}.cross_attn.{o}.diff_b" + converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) if is_i2v_lora: for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]): - converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight" - ) - converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight" - ) - if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict: - converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.diff_b" - ) + original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight" + converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight" + converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + original_key = f"blocks.{i}.cross_attn.{o}.diff_b" + converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) # FFN for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]): @@ -1674,10 +1689,10 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): if original_key in original_state_dict: converted_state_dict[converted_key] = original_state_dict.pop(original_key) - if f"blocks.{i}.{o}.diff_b" in original_state_dict: - converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.bias"] = original_state_dict.pop( - f"blocks.{i}.{o}.diff_b" - ) + original_key = f"blocks.{i}.{o}.diff_b" + converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) # Remaining. if original_state_dict: From 1bc6f3dc0f21779480db70a4928d14282c0198ed Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Tue, 17 Jun 2025 12:19:27 +0300 Subject: [PATCH 6/8] [LoRA training] update metadata use for lora alpha + README (#11723) * lora alpha * Apply style fixes * Update examples/advanced_diffusion_training/README_flux.md Co-authored-by: Sayak Paul * fix readme format --------- Co-authored-by: github-actions[bot] Co-authored-by: Sayak Paul --- .../README_flux.md | 18 ++++++++ .../test_dreambooth_lora_flux_advanced.py | 45 +++++++++++++++++++ .../train_dreambooth_lora_flux_advanced.py | 21 +++++++-- examples/dreambooth/README_flux.md | 17 +++++++ 4 files changed, 98 insertions(+), 3 deletions(-) diff --git a/examples/advanced_diffusion_training/README_flux.md b/examples/advanced_diffusion_training/README_flux.md index c05fa26cf9..62f9078949 100644 --- a/examples/advanced_diffusion_training/README_flux.md +++ b/examples/advanced_diffusion_training/README_flux.md @@ -76,6 +76,24 @@ This command will prompt you for a token. Copy-paste yours from your [settings/t > `pip install wandb` > Alternatively, you can use other tools / train without reporting by modifying the flag `--report_to="wandb"`. +### LoRA Rank and Alpha +Two key LoRA hyperparameters are LoRA rank and LoRA alpha. +- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters). +- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by lora_alpha / lora_rank. +- lora_alpha vs. rank: +This ratio dictates the LoRA's effective strength: +lora_alpha == rank: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16) +lora_alpha < rank: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16) +lora_alpha > rank: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16) + +> [!TIP] +> A common starting point is to set `lora_alpha` equal to `rank`. +> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16) +> to give the LoRA updates more influence without increasing parameter count. +> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank` +> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case. + + ### Target Modules When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them. More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore diff --git a/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py index e29c998213..d465b7de85 100644 --- a/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging import os import sys @@ -20,6 +21,8 @@ import tempfile import safetensors +from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY + sys.path.append("..") from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 @@ -281,3 +284,45 @@ class DreamBoothLoRAFluxAdvanced(ExamplesTestsAccelerate): run_command(self._launch_args + resume_run_args) self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"}) + + def test_dreambooth_lora_with_metadata(self): + # Use a `lora_alpha` that is different from `rank`. + lora_alpha = 8 + rank = 4 + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --lora_alpha={lora_alpha} + --rank={rank} + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors") + self.assertTrue(os.path.isfile(state_dict_file)) + + # Check if the metadata was properly serialized. + with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f: + metadata = f.metadata() or {} + + metadata.pop("format", None) + raw = metadata.get(LORA_ADAPTER_METADATA_KEY) + if raw: + raw = json.loads(raw) + + loaded_lora_alpha = raw["transformer.lora_alpha"] + self.assertTrue(loaded_lora_alpha == lora_alpha) + loaded_lora_rank = raw["transformer.r"] + self.assertTrue(loaded_lora_rank == rank) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index bddab8227a..173d3bfd5b 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -55,6 +55,7 @@ from diffusers import ( ) from diffusers.optimization import get_scheduler from diffusers.training_utils import ( + _collate_lora_metadata, _set_state_dict_into_text_encoder, cast_training_params, compute_density_for_timestep_sampling, @@ -431,6 +432,13 @@ def parse_args(input_args=None): help=("The dimension of the LoRA update matrices."), ) + parser.add_argument( + "--lora_alpha", + type=int, + default=4, + help="LoRA alpha to be used for additional scaling.", + ) + parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers") parser.add_argument( @@ -1556,7 +1564,7 @@ def main(args): # now we will add new LoRA weights to the attention layers transformer_lora_config = LoraConfig( r=args.rank, - lora_alpha=args.rank, + lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout, init_lora_weights="gaussian", target_modules=target_modules, @@ -1565,7 +1573,7 @@ def main(args): if args.train_text_encoder: text_lora_config = LoraConfig( r=args.rank, - lora_alpha=args.rank, + lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], @@ -1582,13 +1590,15 @@ def main(args): if accelerator.is_main_process: transformer_lora_layers_to_save = None text_encoder_one_lora_layers_to_save = None - + modules_to_save = {} for model in models: if isinstance(model, type(unwrap_model(transformer))): transformer_lora_layers_to_save = get_peft_model_state_dict(model) + modules_to_save["transformer"] = model elif isinstance(model, type(unwrap_model(text_encoder_one))): if args.train_text_encoder: # when --train_text_encoder_ti we don't save the layers text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) + modules_to_save["text_encoder"] = model elif isinstance(model, type(unwrap_model(text_encoder_two))): pass # when --train_text_encoder_ti and --enable_t5_ti we don't save the layers else: @@ -1601,6 +1611,7 @@ def main(args): output_dir, transformer_lora_layers=transformer_lora_layers_to_save, text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, + **_collate_lora_metadata(modules_to_save), ) if args.train_text_encoder_ti: embedding_handler.save_embeddings(f"{args.output_dir}/{Path(args.output_dir).name}_emb.safetensors") @@ -2359,16 +2370,19 @@ def main(args): # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: + modules_to_save = {} transformer = unwrap_model(transformer) if args.upcast_before_saving: transformer.to(torch.float32) else: transformer = transformer.to(weight_dtype) transformer_lora_layers = get_peft_model_state_dict(transformer) + modules_to_save["transformer"] = transformer if args.train_text_encoder: text_encoder_one = unwrap_model(text_encoder_one) text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32)) + modules_to_save["text_encoder"] = text_encoder_one else: text_encoder_lora_layers = None @@ -2377,6 +2391,7 @@ def main(args): save_directory=args.output_dir, transformer_lora_layers=transformer_lora_layers, text_encoder_lora_layers=text_encoder_lora_layers, + **_collate_lora_metadata(modules_to_save), ) if args.train_text_encoder_ti: diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index aa43b00faf..a3704f2789 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -170,6 +170,23 @@ accelerate launch train_dreambooth_lora_flux.py \ --push_to_hub ``` +### LoRA Rank and Alpha +Two key LoRA hyperparameters are LoRA rank and LoRA alpha. +- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters). +- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by lora_alpha / lora_rank. +- lora_alpha vs. rank: +This ratio dictates the LoRA's effective strength: +lora_alpha == rank: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16) +lora_alpha < rank: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16) +lora_alpha > rank: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16) + +> [!TIP] +> A common starting point is to set `lora_alpha` equal to `rank`. +> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16) +> to give the LoRA updates more influence without increasing parameter count. +> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank` +> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case. + ### Target Modules When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them. More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore From 5ce4814af1de6d2dc2cc67a46d3862ce62261e2b Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Tue, 17 Jun 2025 20:16:03 -0700 Subject: [PATCH 7/8] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20Speed=20up=20method=20?= =?UTF-8?q?`AutoencoderKLWan.clear=5Fcache`=20by=20886%=20(#11665)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ⚡️ Speed up method `AutoencoderKLWan.clear_cache` by 886% **Key optimizations:** - Compute the number of `WanCausalConv3d` modules in each model (`encoder`/`decoder`) **only once during initialization**, store in `self._cached_conv_counts`. This removes unnecessary repeated tree traversals at every `clear_cache` call, which was the main bottleneck (from profiling). - The internal helper `_count_conv3d_fast` is optimized via a generator expression with `sum` for efficiency. All comments from the original code are preserved, except for updated or removed local docstrings/comments relevant to changed lines. **Function signatures and outputs remain unchanged.** * Apply style fixes * Apply suggestions from code review Co-authored-by: Aryan * Apply style fixes --------- Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] Co-authored-by: Aryan Co-authored-by: Aryan Co-authored-by: Aseem Saxena --- .../models/autoencoders/autoencoder_kl_wan.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index fe00d8c078..49cefcd8a1 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -749,6 +749,16 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin): self.tile_sample_stride_height = 192 self.tile_sample_stride_width = 192 + # Precompute and cache conv counts for encoder and decoder for clear_cache speedup + self._cached_conv_counts = { + "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules()) + if self.decoder is not None + else 0, + "encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules()) + if self.encoder is not None + else 0, + } + def enable_tiling( self, tile_sample_min_height: Optional[int] = None, @@ -801,18 +811,12 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin): self.use_slicing = False def clear_cache(self): - def _count_conv3d(model): - count = 0 - for m in model.modules(): - if isinstance(m, WanCausalConv3d): - count += 1 - return count - - self._conv_num = _count_conv3d(self.decoder) + # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call + self._conv_num = self._cached_conv_counts["decoder"] self._conv_idx = [0] self._feat_map = [None] * self._conv_num # cache encode - self._enc_conv_num = _count_conv3d(self.encoder) + self._enc_conv_num = self._cached_conv_counts["encoder"] self._enc_conv_idx = [0] self._enc_feat_map = [None] * self._enc_conv_num From d72184eba358b883d7186a0a96dedd8118fcb72a Mon Sep 17 00:00:00 2001 From: Leo Jiang Date: Tue, 17 Jun 2025 21:56:02 -0600 Subject: [PATCH 8/8] [training] add ds support to lora hidream (#11737) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [training] add ds support to lora hidream * Apply style fixes --------- Co-authored-by: J石页 Co-authored-by: Sayak Paul Co-authored-by: github-actions[bot] --- .../train_dreambooth_lora_hidream.py | 29 ++++++++++++------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index f368fb809e..a1337e8dba 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -29,7 +29,7 @@ from pathlib import Path import numpy as np import torch import transformers -from accelerate import Accelerator +from accelerate import Accelerator, DistributedType from accelerate.logging import get_logger from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder @@ -1181,13 +1181,15 @@ def main(args): transformer_lora_layers_to_save = None for model in models: - if isinstance(model, type(unwrap_model(transformer))): + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + model = unwrap_model(model) transformer_lora_layers_to_save = get_peft_model_state_dict(model) else: raise ValueError(f"unexpected save model: {model.__class__}") # make sure to pop weight so that corresponding model is not saved again - weights.pop() + if weights: + weights.pop() HiDreamImagePipeline.save_lora_weights( output_dir, @@ -1197,13 +1199,20 @@ def main(args): def load_model_hook(models, input_dir): transformer_ = None - while len(models) > 0: - model = models.pop() + if not accelerator.distributed_type == DistributedType.DEEPSPEED: + while len(models) > 0: + model = models.pop() - if isinstance(model, type(unwrap_model(transformer))): - transformer_ = model - else: - raise ValueError(f"unexpected save model: {model.__class__}") + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + model = unwrap_model(model) + transformer_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + else: + transformer_ = HiDreamImageTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="transformer" + ) + transformer_.add_adapter(transformer_lora_config) lora_state_dict = HiDreamImagePipeline.lora_state_dict(input_dir) @@ -1655,7 +1664,7 @@ def main(args): progress_bar.update(1) global_step += 1 - if accelerator.is_main_process: + if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: