
Add FSDP option for Flux2

commit af339debf4
parent 6cfac4642f
Author: js1234567
Date:   2025-12-24 17:11:05 +08:00

4 changed files with 10 additions and 13 deletions

File 1 of 4

@@ -100,7 +100,7 @@ This way, the text encoder model is not loaded into memory during training.
> to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`.
### FSDP Text Encoder
Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--fsdp_text_encoder` flag to enable distributed computation of the prompt embeddings.
- This way, the memory cost can be distributed in multiple nodes.
+ This way, it distributes the memory cost across multiple nodes.
### CPU Offloading
To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed.
### Latent Caching

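The `--fsdp_text_encoder` flag documented in the hunk above amounts to sharding the Mistral Small 3.1 text encoder across ranks, so that no single GPU has to hold all of its weights while prompt embeddings are computed. Below is a minimal sketch of that idea using plain PyTorch FSDP rather than the training script's own code path; the checkpoint id, wrap policy, and function name are illustrative assumptions.

```python
# Minimal sketch: shard a large text encoder with PyTorch FSDP so each rank
# stores only a slice of the parameters while prompt embeddings are computed.
# Illustration only -- not the train script's code path. The checkpoint id and
# wrap policy are assumptions. Launch with e.g.:
#   torchrun --nproc_per_node=2 shard_text_encoder.py
import functools
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from transformers import Mistral3ForConditionalGeneration, PixtralProcessor

MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"  # placeholder checkpoint


def encode_prompt(prompt: str) -> torch.Tensor:
    dist.init_process_group("nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    processor = PixtralProcessor.from_pretrained(MODEL_ID)
    text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
        MODEL_ID, torch_dtype=torch.bfloat16
    )
    # Shard the parameters across ranks; layers are gathered on the fly
    # during the forward pass and released afterwards.
    text_encoder = FSDP(
        text_encoder,
        auto_wrap_policy=functools.partial(size_based_auto_wrap_policy, min_num_params=1_000_000),
        device_id=torch.cuda.current_device(),
    )

    inputs = processor(text=prompt, return_tensors="pt").to(torch.cuda.current_device())
    with torch.no_grad():
        out = text_encoder(**inputs, output_hidden_states=True)
    prompt_embeds = out.hidden_states[-1]

    # Once embeddings are cached, the encoder can be released entirely.
    del text_encoder
    torch.cuda.empty_cache()
    return prompt_embeds
```

In the training script the same effect is requested with `--fsdp_text_encoder`, and `--offload` can additionally keep the VAE and text encoder in CPU memory between uses.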
File 2 of 4

@@ -44,6 +44,7 @@ import shutil
import warnings
from contextlib import nullcontext
from pathlib import Path
+ from typing import Any
import numpy as np
import torch
@@ -63,7 +64,6 @@ from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.auto import tqdm
from transformers import Mistral3ForConditionalGeneration, PixtralProcessor
- from typing import Any
import diffusers
from diffusers import (
@@ -1292,7 +1292,7 @@ def main(args):
raise ValueError("No transformer model found in 'models'")
# 2) Optionally gather FSDP state dict once
- state_dict = accelerator.get_state_dict(models) if is_fsdp else None
+ state_dict = accelerator.get_state_dict(model) if is_fsdp else None
# 3) Only main process materializes the LoRA state dict
transformer_lora_layers_to_save = None
@@ -1302,8 +1302,8 @@ def main(args):
peft_kwargs["state_dict"] = state_dict
transformer_lora_layers_to_save = get_peft_model_state_dict(
- unwrap_model(transformer_model) if is_fsdp else transformer_model
- ** peft_kwargs,
+ unwrap_model(transformer_model) if is_fsdp else transformer_model,
+ **peft_kwargs,
)
if is_fsdp:

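The two hunks above fix the FSDP saving path: `accelerator.get_state_dict` must receive the single transformer model (`model`), not the whole `models` list, and the `get_peft_model_state_dict` call was missing the comma after the model argument, so Python parsed `transformer_model ** peft_kwargs` as exponentiation inside the conditional expression rather than keyword-argument unpacking. A minimal sketch of the resulting pattern, assuming `transformer` is the LoRA-injected model prepared by Accelerate (the function name is illustrative, not a helper from the script):

```python
# Sketch of the FSDP-aware LoRA saving pattern the hunks above converge on.
# Assumes `transformer` is a PEFT/LoRA-injected model prepared by Accelerate.
from peft.utils import get_peft_model_state_dict


def save_lora_state_dict(accelerator, transformer, is_fsdp: bool):
    # 1) Gather the (possibly sharded) full state dict once. Under FSDP this
    #    is a collective call, so every rank has to reach it.
    state_dict = accelerator.get_state_dict(transformer) if is_fsdp else None

    # 2) Only the main process materializes the LoRA-only state dict.
    if not accelerator.is_main_process:
        return None

    peft_kwargs = {}
    if is_fsdp:
        peft_kwargs["state_dict"] = state_dict

    return get_peft_model_state_dict(
        accelerator.unwrap_model(transformer) if is_fsdp else transformer,
        **peft_kwargs,
    )
```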
File 3 of 4

@@ -43,6 +43,7 @@ import random
import shutil
from contextlib import nullcontext
from pathlib import Path
+ from typing import Any
import numpy as np
import torch
@@ -61,7 +62,6 @@ from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.auto import tqdm
from transformers import Mistral3ForConditionalGeneration, PixtralProcessor
- from typing import Any
import diffusers
from diffusers import (
@@ -75,7 +75,7 @@ from diffusers.optimization import get_scheduler
from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor
from diffusers.training_utils import (
_collate_lora_metadata,
- _to_cpu_contiguous
+ _to_cpu_contiguous,
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
@@ -1229,7 +1229,7 @@ def main(args):
raise ValueError("No transformer model found in 'models'")
# 2) Optionally gather FSDP state dict once
- state_dict = accelerator.get_state_dict(models) if is_fsdp else None
+ state_dict = accelerator.get_state_dict(model) if is_fsdp else None
# 3) Only main process materializes the LoRA state dict
transformer_lora_layers_to_save = None
@@ -1239,7 +1239,7 @@ def main(args):
peft_kwargs["state_dict"] = state_dict
transformer_lora_layers_to_save = get_peft_model_state_dict(
- unwrap_model(transformer_model) if is_fsdp else transformer_model
+ unwrap_model(transformer_model) if is_fsdp else transformer_model,
**peft_kwargs,
)

File 4 of 4

@@ -403,10 +403,7 @@ def find_nearest_bucket(h, w, bucket_options):
def _to_cpu_contiguous(state_dicts) -> dict:
- return {
-     k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
-     for k, v in state_dicts.items()
- }
+ return {k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v for k, v in state_dicts.items()}
def get_fsdp_kwargs_from_accelerator(accelerator) -> dict:
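`_to_cpu_contiguous` (reduced to a one-liner above) detaches every tensor in a state dict, moves it to CPU, and forces contiguous storage so the gathered weights can be serialized safely; safetensors, for instance, rejects non-contiguous tensors. A small usage sketch, with the helper copied inline so the snippet is self-contained and with an illustrative key and file name:

```python
# Usage sketch for _to_cpu_contiguous: prepare a gathered state dict for
# serialization by detaching, moving to CPU, and making storage contiguous.
import torch
from safetensors.torch import save_file


def _to_cpu_contiguous(state_dicts) -> dict:
    return {k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v for k, v in state_dicts.items()}


# A transposed view is non-contiguous; in training this would be a gathered
# (and possibly GPU-resident) LoRA weight.
lora_state_dict = {"transformer.blocks.0.lora_A.weight": torch.randn(64, 16).t()}
save_file(_to_cpu_contiguous(lora_state_dict), "flux2_lora.safetensors")
```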