
Add --lora_alpha and metadata handling to train_dreambooth_lora_sana.py (#11744)

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>

Author: imbr92
Date: 2025-06-23 12:46:44 +00:00
Committed by: GitHub
Parent: 798265f2b6
Commit: 6760300202
Repository: https://github.com/huggingface/diffusers.git

2 changed files with 56 additions and 4 deletions

examples/dreambooth/test_dreambooth_lora_sana.py

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import logging
 import os
 import sys
@@ -20,6 +21,8 @@ import tempfile
+import safetensors
+from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
 sys.path.append("..")
 from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
@@ -204,3 +207,42 @@ class DreamBoothLoRASANA(ExamplesTestsAccelerate):
             run_command(self._launch_args + resume_run_args)

             self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
+
+    def test_dreambooth_lora_sana_with_metadata(self):
+        lora_alpha = 8
+        rank = 4
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --resolution=32
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=4
+                --lora_alpha={lora_alpha}
+                --rank={rank}
+                --checkpointing_steps=2
+                --max_sequence_length 166
+                """.split()
+            test_args.extend(["--instance_prompt", ""])
+            run_command(self._launch_args + test_args)
+
+            state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            self.assertTrue(os.path.isfile(state_dict_file))
+
+            # Check if the metadata was properly serialized.
+            with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
+                metadata = f.metadata() or {}
+
+            metadata.pop("format", None)
+            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+            if raw:
+                raw = json.loads(raw)
+
+            loaded_lora_alpha = raw["transformer.lora_alpha"]
+            self.assertTrue(loaded_lora_alpha == lora_alpha)
+            loaded_lora_rank = raw["transformer.r"]
+            self.assertTrue(loaded_lora_rank == rank)

examples/dreambooth/train_dreambooth_lora_sana.py

@@ -52,6 +52,7 @@ from diffusers import (
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
+    _collate_lora_metadata,
     cast_training_params,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
@@ -323,9 +324,13 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
+    )
     parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
     parser.add_argument(
         "--with_prior_preservation",
         default=False,
@@ -1023,7 +1028,7 @@ def main(args):
     # now we will add new LoRA weights the transformer layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
-        lora_alpha=args.rank,
+        lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         init_lora_weights="gaussian",
         target_modules=target_modules,
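
Before this change lora_alpha was silently pinned to the rank, which fixes the peft scaling factor lora_alpha / r at 1.0; exposing the flag lets the adapter's strength be tuned independently of its rank. A rough, self-contained illustration of that standard peft-style scaling (not code from this commit; shapes and values are made up):

import torch

# Standard LoRA: the adapted layer computes W x + (lora_alpha / r) * B (A x).
r, lora_alpha = 4, 8          # --rank and --lora_alpha
d_in, d_out = 64, 64
W = torch.randn(d_out, d_in)  # frozen base weight
A = torch.randn(r, d_in)      # trainable down-projection
B = torch.zeros(d_out, r)     # trainable up-projection (starts at zero)
scaling = lora_alpha / r      # 2.0 here, instead of the old hard-coded 1.0
W_effective = W + scaling * (B @ A)
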
@@ -1039,10 +1044,11 @@ def main(args):
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
+            modules_to_save = {}
             for model in models:
                 if isinstance(model, type(unwrap_model(transformer))):
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["transformer"] = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1052,6 +1058,7 @@ def main(args):
             SanaPipeline.save_lora_weights(
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
             )

     def load_model_hook(models, input_dir):
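
The save hook now keeps a handle on the wrapped transformer so its adapter config can be written next to the weights. The real helper lives in diffusers.training_utils; the sketch below only illustrates the idea and is not the library implementation (the peft_config lookup and the keyword name are assumptions):

# Hypothetical sketch of metadata collation, assuming each tracked module
# carries a peft LoraConfig under its "default" adapter name.
def collate_lora_metadata_sketch(modules_to_save):
    metadata_kwargs = {}
    for name, module in modules_to_save.items():
        if module is None:
            continue
        # e.g. "transformer_lora_adapter_metadata": {"r": 4, "lora_alpha": 8, ...}
        metadata_kwargs[f"{name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
    return metadata_kwargs
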
@@ -1507,15 +1514,18 @@ def main(args):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)
+        modules_to_save = {}
         if args.upcast_before_saving:
             transformer.to(torch.float32)
         else:
             transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
+        modules_to_save["transformer"] = transformer
         SanaPipeline.save_lora_weights(
             save_directory=args.output_dir,
             transformer_lora_layers=transformer_lora_layers,
+            **_collate_lora_metadata(modules_to_save),
         )

     # Final inference
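
Once training finishes, the LoRA written to --output_dir can be loaded back for inference. A minimal loading sketch; the base checkpoint id, LoRA path, and prompt are placeholders, and with the serialized metadata the loader no longer has to assume lora_alpha equals the rank:

import torch

from diffusers import SanaPipeline

# Placeholder base checkpoint; use the same model the LoRA was trained against.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("path/to/output_dir")  # directory written by the training script
image = pipe(prompt="a photo of sks dog", num_inference_steps=20).images[0]
image.save("sana_lora_sample.png")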