Add FSDP option for Flux2
@@ -98,6 +98,9 @@ Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take
This way, the text encoder model is not loaded into memory during training.

> [!NOTE]
> To enable remote text encoding, you must either be logged in to your Hugging Face account (`hf auth login`) or pass a token with `--hub_token`.

### FSDP Text Encoder

Flux.2 uses Mistral Small 3.1 as its text encoder, which is quite large and can take up a lot of memory. To mitigate this, we can use the `--fsdp_text_encoder` flag to enable distributed computation of the prompt embeddings, as sketched below.

This way, the memory cost can be distributed across multiple nodes.
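
The idea behind the flag, in rough terms, is to shard the text encoder with PyTorch FSDP so that each rank holds only a slice of its weights while the prompt embeddings are computed. The snippet below is a minimal sketch of that pattern, not the script's actual implementation; it assumes the process group has already been initialized by `accelerate launch` or `torchrun`.

```python
import functools

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy


def shard_text_encoder(text_encoder: torch.nn.Module) -> torch.nn.Module:
    # Shard the encoder's parameters across all ranks; the wrap policy splits the
    # model into sub-units so only one unit's full weights are gathered at a time.
    return FSDP(
        text_encoder,
        auto_wrap_policy=functools.partial(size_based_auto_wrap_policy, min_num_params=int(1e6)),
        device_id=torch.cuda.current_device(),
    )


@torch.no_grad()
def encode_prompts(sharded_encoder: torch.nn.Module, input_ids: torch.Tensor) -> torch.Tensor:
    # Every rank runs the same forward pass; no single GPU materializes the whole encoder.
    return sharded_encoder(input_ids)
```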

### CPU Offloading

To offload parts of the model to CPU memory, you can use the `--offload` flag. This will offload the VAE and text encoder to CPU memory and only move them to the GPU when needed.
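
The underlying pattern looks roughly like the following toy sketch (stand-in modules instead of the real VAE and text encoder, not the script's actual code):

```python
import torch
from torch import nn

# Stand-ins for the real VAE and text encoder, just so the sketch runs on its own.
text_encoder = nn.Linear(16, 16)
vae = nn.Linear(16, 16)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Keep the auxiliary models in CPU memory by default.
text_encoder.to("cpu")
vae.to("cpu")


@torch.no_grad()
def encode_prompts(inputs: torch.Tensor) -> torch.Tensor:
    # Move the encoder to the accelerator only for the forward pass...
    text_encoder.to(device)
    embeds = text_encoder(inputs.to(device))
    # ...then send it back so GPU memory stays free for the transformer.
    text_encoder.to("cpu")
    return embeds


prompt_embeds = encode_prompts(torch.randn(2, 16))
```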

### Latent Caching

@@ -47,7 +47,6 @@ from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
@@ -64,6 +63,7 @@ from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.auto import tqdm
from transformers import Mistral3ForConditionalGeneration, PixtralProcessor
from typing import Any

import diffusers
from diffusers import (
@@ -76,6 +76,7 @@ from diffusers import (
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
    _collate_lora_metadata,
    _to_cpu_contiguous,
    cast_training_params,
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3,
@@ -96,6 +97,9 @@ from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module


if getattr(torch, "distributed", None) is not None:
    import torch.distributed as dist

if is_wandb_available():
    import wandb

@@ -1271,43 +1275,44 @@ def main(args):

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        transformer_cls = type(unwrap_model(transformer))

        # 1) Validate and pick the transformer model
        modules_to_save: dict[str, Any] = {}
        transformer_model = None

        for model in models:
            if isinstance(unwrap_model(model), transformer_cls):
                transformer_model = model
                modules_to_save["transformer"] = model
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")

        if transformer_model is None:
            raise ValueError("No transformer model found in 'models'")

        # 2) Optionally gather FSDP state dict once
        state_dict = accelerator.get_state_dict(models) if is_fsdp else None

        # 3) Only main process materializes the LoRA state dict
        transformer_lora_layers_to_save = None
        modules_to_save = {}

        if is_fsdp:
            for model in models:
                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
                    state_dict = accelerator.get_state_dict(models)

                    if accelerator.is_main_process:
                        transformer_lora_layers_to_save = get_peft_model_state_dict(
                            unwrap_model(model),
                            state_dict=state_dict,
                        )
                        transformer_lora_layers_to_save = {
                            k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
                            for k, v in transformer_lora_layers_to_save.items()
                        }
                    modules_to_save["transformer"] = model

            # make sure to pop weight so that corresponding model is not saved again
            if weights:
                weights.pop()
        else:
            if accelerator.is_main_process:
                transformer_lora_layers_to_save = None
                modules_to_save = {}
                for model in models:
                    if isinstance(model, type(unwrap_model(transformer))):
                        transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                        modules_to_save["transformer"] = model
                    else:
                        raise ValueError(f"unexpected save model: {model.__class__}")

                    # make sure to pop weight so that corresponding model is not saved again
                    weights.pop()

        if accelerator.is_main_process:
            peft_kwargs = {}
            if is_fsdp:
                peft_kwargs["state_dict"] = state_dict

            transformer_lora_layers_to_save = get_peft_model_state_dict(
                unwrap_model(transformer_model) if is_fsdp else transformer_model,
                **peft_kwargs,
            )

            if is_fsdp:
                transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)

        # make sure to pop weight so that corresponding model is not saved again
        if weights:
            weights.pop()

        Flux2Pipeline.save_lora_weights(
            output_dir,
            transformer_lora_layers=transformer_lora_layers_to_save,
@@ -46,7 +46,6 @@ from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
@@ -62,6 +61,7 @@ from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.auto import tqdm
from transformers import Mistral3ForConditionalGeneration, PixtralProcessor
from typing import Any

import diffusers
from diffusers import (
@@ -75,6 +75,7 @@ from diffusers.optimization import get_scheduler
from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor
from diffusers.training_utils import (
    _collate_lora_metadata,
    _to_cpu_contiguous,
    cast_training_params,
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3,
@@ -96,6 +97,9 @@ from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module


if getattr(torch, "distributed", None) is not None:
    import torch.distributed as dist

if is_wandb_available():
    import wandb

@@ -1208,42 +1212,44 @@ def main(args):

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        transformer_cls = type(unwrap_model(transformer))

        # 1) Validate and pick the transformer model
        modules_to_save: dict[str, Any] = {}
        transformer_model = None

        for model in models:
            if isinstance(unwrap_model(model), transformer_cls):
                transformer_model = model
                modules_to_save["transformer"] = model
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")

        if transformer_model is None:
            raise ValueError("No transformer model found in 'models'")

        # 2) Optionally gather FSDP state dict once
        state_dict = accelerator.get_state_dict(models) if is_fsdp else None

        # 3) Only main process materializes the LoRA state dict
        transformer_lora_layers_to_save = None
        modules_to_save = {}
        if is_fsdp:
            for model in models:
                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
                    state_dict = accelerator.get_state_dict(models)

                    if accelerator.is_main_process:
                        transformer_lora_layers_to_save = get_peft_model_state_dict(
                            unwrap_model(model),
                            state_dict=state_dict,
                        )
                        transformer_lora_layers_to_save = {
                            k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
                            for k, v in transformer_lora_layers_to_save.items()
                        }
                    modules_to_save["transformer"] = model

            # make sure to pop weight so that corresponding model is not saved again
            if weights:
                weights.pop()
        else:
            if accelerator.is_main_process:
                transformer_lora_layers_to_save = None
                modules_to_save = {}
                for model in models:
                    if isinstance(model, type(unwrap_model(transformer))):
                        transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                        modules_to_save["transformer"] = model
                    else:
                        raise ValueError(f"unexpected save model: {model.__class__}")

                    # make sure to pop weight so that corresponding model is not saved again
                    weights.pop()

        if accelerator.is_main_process:
            peft_kwargs = {}
            if is_fsdp:
                peft_kwargs["state_dict"] = state_dict

            transformer_lora_layers_to_save = get_peft_model_state_dict(
                unwrap_model(transformer_model) if is_fsdp else transformer_model,
                **peft_kwargs,
            )

            if is_fsdp:
                transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)

        # make sure to pop weight so that corresponding model is not saved again
        if weights:
            weights.pop()

        Flux2Pipeline.save_lora_weights(
            output_dir,
            transformer_lora_layers=transformer_lora_layers_to_save,
@@ -402,6 +402,13 @@ def find_nearest_bucket(h, w, bucket_options):
    return best_bucket_idx


def _to_cpu_contiguous(state_dicts) -> dict:
    return {
        k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
        for k, v in state_dicts.items()
    }


def get_fsdp_kwargs_from_accelerator(accelerator) -> dict:
    """
    Extract and convert FSDP config from Accelerator into PyTorch FSDP kwargs.