
[training] fix: registration of out_channels in the control flux scripts. (#10367)

* fix: registration of out_channels in the control flux scripts.

* free memory.
Author: Sayak Paul
Committed: 2024-12-24 21:44:44 +05:30 (via GitHub)
Parent: 023b0e0d55
Commit: 825979ddc3

2 changed files with 12 additions and 2 deletions

File 1 of 2:

@@ -795,7 +795,7 @@ def main(args):
     flux_transformer.x_embedder = new_linear
 
     assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
-    flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
+    flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
 
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
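Why the one-line change matters: control training widens `x_embedder` so the transformer takes the image latents concatenated with the control latents, and re-registers the doubled `in_channels`. In `FluxTransformer2DModel` a missing `out_channels` falls back to `in_channels`, so a checkpoint saved with only the doubled `in_channels` would be rebuilt with a doubled output size too. Below is a minimal sketch of that failure mode; `ToyTransformer` is a stand-in for the real model, and the fallback behavior it encodes is an assumption stated in the comments, not something shown in the diff:

class ToyTransformer:
    def __init__(self, in_channels=64, out_channels=None):
        self.config = {"in_channels": in_channels, "out_channels": out_channels}
        # Assumption: like FluxTransformer2DModel, a None out_channels
        # falls back to in_channels at construction time.
        self.out_channels = out_channels or in_channels

    def register_to_config(self, **kwargs):
        # Stand-in for diffusers' ConfigMixin.register_to_config, which
        # updates the config that gets serialized with the model.
        self.config.update(kwargs)


model = ToyTransformer(in_channels=64)
initial_input_channels = model.config["in_channels"]

# Bug: only in_channels is re-registered after widening x_embedder, so a
# model rebuilt from the saved config resolves out_channels to 128.
model.register_to_config(in_channels=initial_input_channels * 2)
assert ToyTransformer(**model.config).out_channels == initial_input_channels * 2

# Fix: pin out_channels to the original value alongside the doubled input.
model.register_to_config(
    in_channels=initial_input_channels * 2, out_channels=initial_input_channels
)
assert ToyTransformer(**model.config).out_channels == initial_input_channels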
@@ -1166,6 +1166,11 @@ def main(args):
         flux_transformer.to(torch.float32)
         flux_transformer.save_pretrained(args.output_dir)
 
+        del flux_transformer
+        del text_encoding_pipeline
+        del vae
+        free_memory()
+
         # Run a final round of validation.
         image_logs = None
         if args.validation_prompt is not None:
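The second hunk frees GPU memory before the final validation round: once the trained transformer is saved, it and the frozen text-encoding pipeline and VAE are no longer needed in their training-time form. `free_memory()` is the helper the scripts import (diffusers' `training_utils` exposes one by that name); here is a self-contained sketch of the pattern, with stand-in tensors in place of the real models:

import gc

import torch


def free_memory():
    # Sketch comparable to diffusers.training_utils.free_memory: collect
    # dead Python objects first, then release cached CUDA blocks.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# Stand-ins for the large objects the script drops after saving.
flux_transformer = torch.zeros(1024, 1024)
text_encoding_pipeline = torch.zeros(1024, 1024)
vae = torch.zeros(1024, 1024)

# `del` alone only removes the Python references; free_memory() releases
# the cached allocator blocks so the validation pipeline has room.
del flux_transformer, text_encoding_pipeline, vae
free_memory()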

File 2 of 2 (the same two fixes, applied to the LoRA variant of the script):

@@ -830,7 +830,7 @@ def main(args):
     flux_transformer.x_embedder = new_linear
 
    assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
-    flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
+    flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
 
     if args.train_norm_layers:
         for name, param in flux_transformer.named_parameters():
@@ -1319,6 +1319,11 @@ def main(args):
             transformer_lora_layers=transformer_lora_layers,
         )
 
+        del flux_transformer
+        del text_encoding_pipeline
+        del vae
+        free_memory()
+
         # Run a final round of validation.
         image_logs = None
         if args.validation_prompt is not None: