mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Merge branch 'main' into group-offloading-with-disk
@@ -180,6 +180,8 @@
    title: Caching
  - local: optimization/memory
    title: Reduce memory usage
  - local: optimization/pruna
    title: Pruna
  - local: optimization/xformers
    title: xFormers
  - local: optimization/tome
docs/source/en/optimization/pruna.md (new file, 187 lines)
@@ -0,0 +1,187 @@
# Pruna

[Pruna](https://github.com/PrunaAI/pruna) is a model optimization framework that offers various optimization methods - quantization, pruning, caching, compilation - for accelerating inference and reducing memory usage. A general overview of the optimization methods is shown below.

| Technique | Description | Speed | Memory | Quality |
|--------------|-----------------------------------------------------------------------------------------------|:-----:|:------:|:-------:|
| `batcher` | Groups multiple inputs together to be processed simultaneously, improving computational efficiency and reducing processing time. | ✅ | ❌ | ➖ |
| `cacher` | Stores intermediate results of computations to speed up subsequent operations. | ✅ | ➖ | ➖ |
| `compiler` | Optimizes the model with instructions for specific hardware. | ✅ | ➖ | ➖ |
| `distiller` | Trains a smaller, simpler model to mimic a larger, more complex model. | ✅ | ✅ | ❌ |
| `quantizer` | Reduces the precision of weights and activations, lowering memory requirements. | ✅ | ✅ | ❌ |
| `pruner` | Removes less important or redundant connections and neurons, resulting in a sparser, more efficient network. | ✅ | ✅ | ❌ |
| `recoverer` | Restores the performance of a model after compression. | ➖ | ➖ | ✅ |
| `factorizer` | Factorization batches several small matrix multiplications into one large fused operation. | ✅ | ➖ | ➖ |
| `enhancer` | Enhances the model output by applying post-processing algorithms such as denoising or upscaling. | ❌ | ➖ | ✅ |

✅ (improves), ➖ (approx. the same), ❌ (worsens)

Explore the full range of optimization methods in the [Pruna documentation](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html#configure-algorithms).
## Installation

Install Pruna with the following command.

```bash
pip install pruna
```

## Optimize Diffusers models

A broad range of optimization algorithms are supported for Diffusers models as shown below.

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/diffusers_combinations.png" alt="Overview of the supported optimization algorithms for diffusers models">
</div>

The example below optimizes [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
with a combination of factorizer, compiler, and cacher algorithms. This combination accelerates inference by up to 4.2x and cuts peak GPU memory usage from 34.7GB to 28.0GB, all while maintaining virtually the same output quality.

> [!TIP]
> Refer to the [Pruna optimization](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html) docs to learn more about the optimization techniques used in this example.

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/flux_combination.png" alt="Optimization techniques used for FLUX.1-dev showing the combination of factorizer, compiler, and cacher algorithms">
</div>

Start by defining a `SmashConfig` with the optimization algorithms to use. To optimize the model, wrap the pipeline and the `SmashConfig` with `smash` and then use the pipeline as normal for inference.

```python
import torch
from diffusers import FluxPipeline

from pruna import PrunaModel, SmashConfig, smash

# load the model
# Try segmind/Segmind-Vega or black-forest-labs/FLUX.1-schnell with a small GPU memory
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16
).to("cuda")

# define the configuration
smash_config = SmashConfig()
smash_config["factorizer"] = "qkv_diffusers"
smash_config["compiler"] = "torch_compile"
smash_config["torch_compile_target"] = "module_list"
smash_config["cacher"] = "fora"
smash_config["fora_interval"] = 2

# for the best results in terms of speed you can add these configs
# however they will increase your warmup time from 1.5 min to 10 min
# smash_config["torch_compile_mode"] = "max-autotune-no-cudagraphs"
# smash_config["quantizer"] = "torchao"
# smash_config["torchao_quant_type"] = "fp8dq"
# smash_config["torchao_excluded_modules"] = "norm+embedding"

# optimize the model
smashed_pipe = smash(pipe, smash_config)

# run the model
smashed_pipe("a knitted purple prune").images[0]
```
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/flux_smashed_comparison.png">
</div>

After optimization, we can share and load the optimized model using the Hugging Face Hub.

```python
# save the model
smashed_pipe.save_to_hub("<username>/FLUX.1-dev-smashed")

# load the model
smashed_pipe = PrunaModel.from_hub("<username>/FLUX.1-dev-smashed")
```

## Evaluate and benchmark Diffusers models

Pruna provides the [EvaluationAgent](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/evaluate.html) to evaluate the quality of your optimized models.

We can define the metrics we care about, such as total time and throughput, as well as the dataset to evaluate on, and then pass a model to the `EvaluationAgent`.

<hfoptions id="eval">
<hfoption id="optimized model">

We can load an optimized model, define a `Task` with the metrics and dataset, and pass it to the `EvaluationAgent` for evaluation.
```python
import torch
from diffusers import FluxPipeline

from pruna import PrunaModel
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.metrics import (
    ThroughputMetric,
    TorchMetricWrapper,
    TotalTimeMetric,
)
from pruna.evaluation.task import Task

# define the device
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# load the model
# Try PrunaAI/Segmind-Vega-smashed or PrunaAI/FLUX.1-dev-smashed with a small GPU memory
smashed_pipe = PrunaModel.from_hub("PrunaAI/FLUX.1-dev-smashed")

# Define the metrics
metrics = [
    TotalTimeMetric(n_iterations=20, n_warmup_iterations=5),
    ThroughputMetric(n_iterations=20, n_warmup_iterations=5),
    TorchMetricWrapper("clip"),
]

# Define the datamodule
datamodule = PrunaDataModule.from_string("LAION256")
datamodule.limit_datasets(10)

# Define the task and evaluation agent
task = Task(metrics, datamodule=datamodule, device=device)
eval_agent = EvaluationAgent(task)

# Evaluate smashed model and offload it to CPU
smashed_pipe.move_to_device(device)
smashed_pipe_results = eval_agent.evaluate(smashed_pipe)
smashed_pipe.move_to_device("cpu")
```
</hfoption>
<hfoption id="standalone model">

Instead of comparing the optimized model to the base model, you can also evaluate the standalone `diffusers` model. This is useful if you want to evaluate the performance of the model without the optimization. We can do so by wrapping the pipeline in `PrunaModel` and running the `EvaluationAgent` on it.

```python
import torch
from diffusers import FluxPipeline

from pruna import PrunaModel

# load the model
# Try segmind/Segmind-Vega or black-forest-labs/FLUX.1-schnell with a small GPU memory
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16
).to("cpu")
wrapped_pipe = PrunaModel(model=pipe)
```
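The wrapped baseline can then be scored with the same metrics. A minimal sketch, assuming the `metrics`, `datamodule`, and `device` objects defined in the previous tab are in scope:

```python
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.task import Task

# evaluate the unoptimized baseline with the same metrics and datamodule
task = Task(metrics, datamodule=datamodule, device=device)
eval_agent = EvaluationAgent(task)

wrapped_pipe.move_to_device(device)
base_results = eval_agent.evaluate(wrapped_pipe)
wrapped_pipe.move_to_device("cpu")
```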
</hfoption>
</hfoptions>

Now that you have seen how to optimize and evaluate your models, you can start using Pruna on your own models. There are many examples to help you get started.

> [!TIP]
> For more details about benchmarking Flux, check out the [Announcing FLUX-Juiced: The Fastest Image Generation Endpoint (2.6 times faster)!](https://huggingface.co/blog/PrunaAI/flux-fastest-image-generation-endpoint) blog post and the [InferBench](https://huggingface.co/spaces/PrunaAI/InferBench) Space.

## Reference

- [Pruna](https://github.com/pruna-ai/pruna)
- [Pruna optimization](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html#configure-algorithms)
- [Pruna evaluation](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/evaluate.html)
- [Pruna tutorials](https://docs.pruna.ai/en/stable/docs_pruna/tutorials/index.html)
@@ -76,6 +76,24 @@ This command will prompt you for a token. Copy-paste yours from your [settings/t
> `pip install wandb`
> Alternatively, you can use other tools / train without reporting by modifying the flag `--report_to="wandb"`.

### LoRA Rank and Alpha
Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by `lora_alpha / lora_rank`.
- `lora_alpha` vs. `rank`: this ratio dictates the LoRA's effective strength:
  - `lora_alpha == rank`: scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
  - `lora_alpha < rank`: scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
  - `lora_alpha > rank`: scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)

> [!TIP]
> A common starting point is to set `lora_alpha` equal to `rank`.
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
> to give the LoRA updates more influence without increasing parameter count.
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
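For illustration, here is a minimal sketch of how these two flags map onto the PEFT `LoraConfig` that the training script builds; the target module names below are hypothetical placeholders rather than the script's exact list:

```python
from peft import LoraConfig

rank = 16
lora_alpha = 32  # effective scaling factor = lora_alpha / rank = 2.0, i.e. a stronger LoRA effect

transformer_lora_config = LoraConfig(
    r=rank,
    lora_alpha=lora_alpha,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # placeholder attention projections
)
```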
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the UNet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the UNet with a diffusion Transformer (DiT). With this change, we may also want to explore
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
import sys
@@ -20,6 +21,8 @@ import tempfile

import safetensors

from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
@@ -281,3 +284,45 @@ class DreamBoothLoRAFluxAdvanced(ExamplesTestsAccelerate):
            run_command(self._launch_args + resume_run_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})

    def test_dreambooth_lora_with_metadata(self):
        # Use a `lora_alpha` that is different from `rank`.
        lora_alpha = 8
        rank = 4
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --lora_alpha={lora_alpha}
                --rank={rank}
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
            self.assertTrue(os.path.isfile(state_dict_file))

            # Check if the metadata was properly serialized.
            with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
                metadata = f.metadata() or {}

            metadata.pop("format", None)
            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
            if raw:
                raw = json.loads(raw)

            loaded_lora_alpha = raw["transformer.lora_alpha"]
            self.assertTrue(loaded_lora_alpha == lora_alpha)
            loaded_lora_rank = raw["transformer.r"]
            self.assertTrue(loaded_lora_rank == rank)
@@ -55,6 +55,7 @@ from diffusers import (
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
    _collate_lora_metadata,
    _set_state_dict_into_text_encoder,
    cast_training_params,
    compute_density_for_timestep_sampling,
@@ -431,6 +432,13 @@ def parse_args(input_args=None):
        help=("The dimension of the LoRA update matrices."),
    )

    parser.add_argument(
        "--lora_alpha",
        type=int,
        default=4,
        help="LoRA alpha to be used for additional scaling.",
    )

    parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")

    parser.add_argument(
@@ -1556,7 +1564,7 @@ def main(args):
    # now we will add new LoRA weights to the attention layers
    transformer_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        init_lora_weights="gaussian",
        target_modules=target_modules,
@@ -1565,7 +1573,7 @@ def main(args):
    if args.train_text_encoder:
        text_lora_config = LoraConfig(
            r=args.rank,
            lora_alpha=args.rank,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            init_lora_weights="gaussian",
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
@@ -1582,13 +1590,15 @@ def main(args):
        if accelerator.is_main_process:
            transformer_lora_layers_to_save = None
            text_encoder_one_lora_layers_to_save = None

            modules_to_save = {}
            for model in models:
                if isinstance(model, type(unwrap_model(transformer))):
                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                    modules_to_save["transformer"] = model
                elif isinstance(model, type(unwrap_model(text_encoder_one))):
                    if args.train_text_encoder: # when --train_text_encoder_ti we don't save the layers
                        text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
                        modules_to_save["text_encoder"] = model
                elif isinstance(model, type(unwrap_model(text_encoder_two))):
                    pass # when --train_text_encoder_ti and --enable_t5_ti we don't save the layers
                else:
@@ -1601,6 +1611,7 @@ def main(args):
                output_dir,
                transformer_lora_layers=transformer_lora_layers_to_save,
                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
                **_collate_lora_metadata(modules_to_save),
            )
            if args.train_text_encoder_ti:
                embedding_handler.save_embeddings(f"{args.output_dir}/{Path(args.output_dir).name}_emb.safetensors")
@@ -2359,16 +2370,19 @@ def main(args):
    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        modules_to_save = {}
        transformer = unwrap_model(transformer)
        if args.upcast_before_saving:
            transformer.to(torch.float32)
        else:
            transformer = transformer.to(weight_dtype)
        transformer_lora_layers = get_peft_model_state_dict(transformer)
        modules_to_save["transformer"] = transformer

        if args.train_text_encoder:
            text_encoder_one = unwrap_model(text_encoder_one)
            text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
            modules_to_save["text_encoder"] = text_encoder_one
        else:
            text_encoder_lora_layers = None

@@ -2377,6 +2391,7 @@ def main(args):
            save_directory=args.output_dir,
            transformer_lora_layers=transformer_lora_layers,
            text_encoder_lora_layers=text_encoder_lora_layers,
            **_collate_lora_metadata(modules_to_save),
        )

        if args.train_text_encoder_ti:
@@ -170,6 +170,23 @@ accelerate launch train_dreambooth_lora_flux.py \
  --push_to_hub
```

### LoRA Rank and Alpha
Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by `lora_alpha / lora_rank`.
- `lora_alpha` vs. `rank`: this ratio dictates the LoRA's effective strength:
  - `lora_alpha == rank`: scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
  - `lora_alpha < rank`: scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
  - `lora_alpha > rank`: scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)

> [!TIP]
> A common starting point is to set `lora_alpha` equal to `rank`.
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
> to give the LoRA updates more influence without increasing parameter count.
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.

### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the UNet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the UNet with a diffusion Transformer (DiT). With this change, we may also want to explore
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
import sys
@@ -20,6 +21,8 @@ import tempfile

import safetensors

from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
@@ -234,3 +237,45 @@ class DreamBoothLoRAFlux(ExamplesTestsAccelerate):
            run_command(self._launch_args + resume_run_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})

    def test_dreambooth_lora_with_metadata(self):
        # Use a `lora_alpha` that is different from `rank`.
        lora_alpha = 8
        rank = 4
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --lora_alpha={lora_alpha}
                --rank={rank}
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
            self.assertTrue(os.path.isfile(state_dict_file))

            # Check if the metadata was properly serialized.
            with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
                metadata = f.metadata() or {}

            metadata.pop("format", None)
            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
            if raw:
                raw = json.loads(raw)

            loaded_lora_alpha = raw["transformer.lora_alpha"]
            self.assertTrue(loaded_lora_alpha == lora_alpha)
            loaded_lora_rank = raw["transformer.r"]
            self.assertTrue(loaded_lora_rank == rank)
@@ -27,7 +27,6 @@ from pathlib import Path
import numpy as np
import torch
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
@@ -53,6 +52,7 @@ from diffusers import (
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
    _collate_lora_metadata,
    _set_state_dict_into_text_encoder,
    cast_training_params,
    compute_density_for_timestep_sampling,
@@ -358,7 +358,12 @@ def parse_args(input_args=None):
        default=4,
        help=("The dimension of the LoRA update matrices."),
    )

    parser.add_argument(
        "--lora_alpha",
        type=int,
        default=4,
        help="LoRA alpha to be used for additional scaling.",
    )
    parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")

    parser.add_argument(
@@ -1238,7 +1243,7 @@ def main(args):
    # now we will add new LoRA weights the transformer layers
    transformer_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        init_lora_weights="gaussian",
        target_modules=target_modules,
@@ -1247,7 +1252,7 @@ def main(args):
    if args.train_text_encoder:
        text_lora_config = LoraConfig(
            r=args.rank,
            lora_alpha=args.rank,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            init_lora_weights="gaussian",
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
@@ -1264,12 +1269,14 @@ def main(args):
        if accelerator.is_main_process:
            transformer_lora_layers_to_save = None
            text_encoder_one_lora_layers_to_save = None

            modules_to_save = {}
            for model in models:
                if isinstance(model, type(unwrap_model(transformer))):
                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                    modules_to_save["transformer"] = model
                elif isinstance(model, type(unwrap_model(text_encoder_one))):
                    text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
                    modules_to_save["text_encoder"] = model
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")

@@ -1280,6 +1287,7 @@ def main(args):
                output_dir,
                transformer_lora_layers=transformer_lora_layers_to_save,
                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
                **_collate_lora_metadata(modules_to_save),
            )

        def load_model_hook(models, input_dir):
@@ -1889,16 +1897,19 @@ def main(args):
    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        modules_to_save = {}
        transformer = unwrap_model(transformer)
        if args.upcast_before_saving:
            transformer.to(torch.float32)
        else:
            transformer = transformer.to(weight_dtype)
        transformer_lora_layers = get_peft_model_state_dict(transformer)
        modules_to_save["transformer"] = transformer

        if args.train_text_encoder:
            text_encoder_one = unwrap_model(text_encoder_one)
            text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
            modules_to_save["text_encoder"] = text_encoder_one
        else:
            text_encoder_lora_layers = None

@@ -1906,6 +1917,7 @@ def main(args):
            save_directory=args.output_dir,
            transformer_lora_layers=transformer_lora_layers,
            text_encoder_lora_layers=text_encoder_lora_layers,
            **_collate_lora_metadata(modules_to_save),
        )

        # Final inference
@@ -29,7 +29,7 @@ from pathlib import Path
import numpy as np
import torch
import transformers
from accelerate import Accelerator
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
@@ -1181,13 +1181,15 @@ def main(args):
            transformer_lora_layers_to_save = None

            for model in models:
                if isinstance(model, type(unwrap_model(transformer))):
                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
                    model = unwrap_model(model)
                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")

                # make sure to pop weight so that corresponding model is not saved again
                weights.pop()
                if weights:
                    weights.pop()

            HiDreamImagePipeline.save_lora_weights(
                output_dir,
@@ -1197,13 +1199,20 @@ def main(args):
        def load_model_hook(models, input_dir):
            transformer_ = None

            while len(models) > 0:
                model = models.pop()
            if not accelerator.distributed_type == DistributedType.DEEPSPEED:
                while len(models) > 0:
                    model = models.pop()

                if isinstance(model, type(unwrap_model(transformer))):
                    transformer_ = model
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")
                    if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
                        model = unwrap_model(model)
                        transformer_ = model
                    else:
                        raise ValueError(f"unexpected save model: {model.__class__}")
            else:
                transformer_ = HiDreamImageTransformer2DModel.from_pretrained(
                    args.pretrained_model_name_or_path, subfolder="transformer"
                )
                transformer_.add_adapter(transformer_lora_config)

            lora_state_dict = HiDreamImagePipeline.lora_state_dict(input_dir)
@@ -1655,7 +1664,7 @@ def main(args):
                progress_bar.update(1)
                global_step += 1

            if accelerator.is_main_process:
            if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                if global_step % args.checkpointing_steps == 0:
                    # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                    if args.checkpoints_total_limit is not None:
@@ -1,4 +1,4 @@
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
# Run this script to convert the Stable Audio model weights to a diffusers pipeline.
import argparse
import json
import os
@@ -1596,7 +1596,10 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
    converted_state_dict = {}
    original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}

    num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict if "blocks." in k})
    block_numbers = {int(k.split(".")[1]) for k in original_state_dict if k.startswith("blocks.")}
    min_block = min(block_numbers)
    max_block = max(block_numbers)

    is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
    lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
    lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
@@ -1622,45 +1625,57 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
    # For the `diff_b` keys, we treat them as lora_bias.
    # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias

    for i in range(num_blocks):
    for i in range(min_block, max_block + 1):
        # Self-attention
        for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
            converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
                f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
            )
            converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
                f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
            )
            if f"blocks.{i}.self_attn.{o}.diff_b" in original_state_dict:
                converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.bias"] = original_state_dict.pop(
                    f"blocks.{i}.self_attn.{o}.diff_b"
                )
            original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
            converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight"
            if original_key in original_state_dict:
                converted_state_dict[converted_key] = original_state_dict.pop(original_key)

            original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
            converted_key = f"blocks.{i}.attn1.{c}.lora_B.weight"
            if original_key in original_state_dict:
                converted_state_dict[converted_key] = original_state_dict.pop(original_key)

            original_key = f"blocks.{i}.self_attn.{o}.diff_b"
            converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias"
            if original_key in original_state_dict:
                converted_state_dict[converted_key] = original_state_dict.pop(original_key)

        # Cross-attention
        for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
                f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
            )
            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
                f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
            )
            if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict:
                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop(
                    f"blocks.{i}.cross_attn.{o}.diff_b"
                )
            original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
            converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
            if original_key in original_state_dict:
                converted_state_dict[converted_key] = original_state_dict.pop(original_key)

            original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
            converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
            if original_key in original_state_dict:
                converted_state_dict[converted_key] = original_state_dict.pop(original_key)

            original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
            converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
            if original_key in original_state_dict:
                converted_state_dict[converted_key] = original_state_dict.pop(original_key)

        if is_i2v_lora:
            for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
                    f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
                )
                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
                    f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
                )
                if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict:
                    converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop(
                        f"blocks.{i}.cross_attn.{o}.diff_b"
                    )
                original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
                converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
                if original_key in original_state_dict:
                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)

                original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
                converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
                if original_key in original_state_dict:
                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)

                original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
                converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
                if original_key in original_state_dict:
                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)

        # FFN
        for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
@@ -1674,10 +1689,10 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
            if original_key in original_state_dict:
                converted_state_dict[converted_key] = original_state_dict.pop(original_key)

            if f"blocks.{i}.{o}.diff_b" in original_state_dict:
                converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.bias"] = original_state_dict.pop(
                    f"blocks.{i}.{o}.diff_b"
                )
            original_key = f"blocks.{i}.{o}.diff_b"
            converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"
            if original_key in original_state_dict:
                converted_state_dict[converted_key] = original_state_dict.pop(original_key)

    # Remaining.
    if original_state_dict:
@@ -2031,18 +2031,36 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
        if is_kohya:
            state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
            # Kohya already takes care of scaling the LoRA parameters with alpha.
            return (state_dict, None) if return_alphas else state_dict
            return cls._prepare_outputs(
                state_dict,
                metadata=metadata,
                alphas=None,
                return_alphas=return_alphas,
                return_metadata=return_lora_metadata,
            )

        is_xlabs = any("processor" in k for k in state_dict)
        if is_xlabs:
            state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
            # xlabs doesn't use `alpha`.
            return (state_dict, None) if return_alphas else state_dict
            return cls._prepare_outputs(
                state_dict,
                metadata=metadata,
                alphas=None,
                return_alphas=return_alphas,
                return_metadata=return_lora_metadata,
            )

        is_bfl_control = any("query_norm.scale" in k for k in state_dict)
        if is_bfl_control:
            state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
            return (state_dict, None) if return_alphas else state_dict
            return cls._prepare_outputs(
                state_dict,
                metadata=metadata,
                alphas=None,
                return_alphas=return_alphas,
                return_metadata=return_lora_metadata,
            )

        # For state dicts like
        # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
@@ -2061,12 +2079,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
            )

        if return_alphas or return_lora_metadata:
            outputs = [state_dict]
            if return_alphas:
                outputs.append(network_alphas)
            if return_lora_metadata:
                outputs.append(metadata)
            return tuple(outputs)
            return cls._prepare_outputs(
                state_dict,
                metadata=metadata,
                alphas=network_alphas,
                return_alphas=return_alphas,
                return_metadata=return_lora_metadata,
            )
        else:
            return state_dict
@@ -2785,6 +2804,15 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
        raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")

    @staticmethod
    def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, return_metadata=False):
        outputs = [state_dict]
        if return_alphas:
            outputs.append(alphas)
        if return_metadata:
            outputs.append(metadata)
        return tuple(outputs) if (return_alphas or return_metadata) else state_dict


# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
@@ -187,7 +187,9 @@ class PeftAdapterMixin:
                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
                limitations to this technique, which are documented here:
                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
            metadata: TODO
            metadata:
                LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to
                initialize `LoraConfig`.
        """
        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
        from peft.tuners.tuners_utils import BaseTunerLayer
@@ -749,6 +749,16 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
        self.tile_sample_stride_height = 192
        self.tile_sample_stride_width = 192

        # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
        self._cached_conv_counts = {
            "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
            if self.decoder is not None
            else 0,
            "encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules())
            if self.encoder is not None
            else 0,
        }

    def enable_tiling(
        self,
        tile_sample_min_height: Optional[int] = None,
@@ -801,18 +811,12 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
        self.use_slicing = False

    def clear_cache(self):
        def _count_conv3d(model):
            count = 0
            for m in model.modules():
                if isinstance(m, WanCausalConv3d):
                    count += 1
            return count

        self._conv_num = _count_conv3d(self.decoder)
        # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
        self._conv_num = self._cached_conv_counts["decoder"]
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        # cache encode
        self._enc_conv_num = _count_conv3d(self.encoder)
        self._enc_conv_num = self._cached_conv_counts["encoder"]
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num
@@ -247,6 +247,14 @@ def _set_state_dict_into_text_encoder(
    set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")


def _collate_lora_metadata(modules_to_save: Dict[str, torch.nn.Module]) -> Dict[str, Any]:
    metadatas = {}
    for module_name, module in modules_to_save.items():
        if module is not None:
            metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
    return metadatas


def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
@@ -359,5 +359,8 @@ def _load_sft_state_dict_metadata(model_file: str):
        metadata = f.metadata() or {}

    metadata.pop("format", None)
    raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
    return json.loads(raw) if raw else None
    if metadata:
        raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
        return json.loads(raw) if raw else None
    else:
        return None