diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py
index e25c7f1669..ff502d9309 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux2.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux2.py
@@ -1281,7 +1281,8 @@ def main(args):
 
             if accelerator.is_main_process:
                 transformer_lora_layers_to_save = get_peft_model_state_dict(
-                    unwrap_model(model), state_dict=state_dict,
+                    unwrap_model(model),
+                    state_dict=state_dict,
                 )
                 transformer_lora_layers_to_save = {
                     k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
@@ -1326,7 +1327,8 @@ def main(args):
                     raise ValueError(f"unexpected save model: {model.__class__}")
         else:
             transformer_ = Flux2Transformer2DModel.from_pretrained(
-                args.pretrained_model_name_or_path, subfolder="transformer",
+                args.pretrained_model_name_or_path,
+                subfolder="transformer",
             )
             transformer_.add_adapter(transformer_lora_config)
 
diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
index 2062994a0d..cd2d493b18 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
@@ -1217,7 +1217,8 @@ def main(args):
 
             if accelerator.is_main_process:
                 transformer_lora_layers_to_save = get_peft_model_state_dict(
-                    unwrap_model(model), state_dict=state_dict,
+                    unwrap_model(model),
+                    state_dict=state_dict,
                 )
                 transformer_lora_layers_to_save = {
                     k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
@@ -1262,7 +1263,8 @@ def main(args):
                     raise ValueError(f"unexpected save model: {model.__class__}")
         else:
             transformer_ = Flux2Transformer2DModel.from_pretrained(
-                args.pretrained_model_name_or_path, subfolder="transformer",
+                args.pretrained_model_name_or_path,
+                subfolder="transformer",
             )
             transformer_.add_adapter(transformer_lora_config)
 
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 56e5fe4e5a..9b09d2c814 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -5,13 +5,15 @@ import math
 import random
 import re
 import warnings
-from accelerate.logging import get_logger
 from contextlib import contextmanager
 from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
 
 import numpy as np
 import torch
+from accelerate.logging import get_logger
+
+
 if getattr(torch, "distributed", None) is not None:
     from torch.distributed.fsdp import CPUOffload, ShardingStrategy
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
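Note on the save path the first two files touch: get_peft_model_state_dict pulls only the LoRA adapter weights out of the model (optionally from a caller-supplied, e.g. FSDP-gathered, state_dict), and the dict comprehension then detaches each tensor and copies it to CPU so serialization does not hold references to GPU memory. A minimal sketch of that pattern, assuming only torch and peft are installed; collect_lora_weights is a hypothetical wrapper name, and the actual scripts first unwrap the model from its accelerate wrapper:

    import torch
    from peft.utils import get_peft_model_state_dict


    def collect_lora_weights(model, state_dict=None):
        # Hypothetical helper mirroring the save hook in the training scripts.
        # `state_dict` lets the caller pass an FSDP-gathered full state dict
        # instead of reading parameters from the live (possibly sharded) module.
        lora_layers = get_peft_model_state_dict(model, state_dict=state_dict)
        # Detach and copy every tensor to CPU so saving the weights does not
        # keep GPU memory alive.
        return {
            k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
            for k, v in lora_layers.items()
        }

The training_utils.py hunk is independent of the above: it moves the accelerate import into the third-party import group and leaves the torch.distributed FSDP imports behind the existing getattr guard, which keeps the module importable on torch builds without distributed support.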