From 825979ddc3d03462287f1f5439e89ccac8cc71e9 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Tue, 24 Dec 2024 21:44:44 +0530
Subject: [PATCH] [training] fix: registration of out_channels in the control
 flux scripts. (#10367)

* fix: registration of out_channels in the control flux scripts.

* free memory.
---
 examples/flux-control/train_control_flux.py      | 7 ++++++-
 examples/flux-control/train_control_lora_flux.py | 7 ++++++-
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/examples/flux-control/train_control_flux.py b/examples/flux-control/train_control_flux.py
index 1432e346f0..35f9a5f803 100644
--- a/examples/flux-control/train_control_flux.py
+++ b/examples/flux-control/train_control_flux.py
@@ -795,7 +795,7 @@ def main(args):
         flux_transformer.x_embedder = new_linear
 
         assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
-        flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
+        flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
 
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
@@ -1166,6 +1166,11 @@ def main(args):
         flux_transformer.to(torch.float32)
         flux_transformer.save_pretrained(args.output_dir)
 
+        del flux_transformer
+        del text_encoding_pipeline
+        del vae
+        free_memory()
+
         # Run a final round of validation.
         image_logs = None
         if args.validation_prompt is not None:
diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py
index 6d84e81d81..b176a685c9 100644
--- a/examples/flux-control/train_control_lora_flux.py
+++ b/examples/flux-control/train_control_lora_flux.py
@@ -830,7 +830,7 @@ def main(args):
         flux_transformer.x_embedder = new_linear
 
         assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
-        flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
+        flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
 
     if args.train_norm_layers:
         for name, param in flux_transformer.named_parameters():
@@ -1319,6 +1319,11 @@ def main(args):
             transformer_lora_layers=transformer_lora_layers,
         )
 
+        del flux_transformer
+        del text_encoding_pipeline
+        del vae
+        free_memory()
+
         # Run a final round of validation.
         image_logs = None
         if args.validation_prompt is not None: