mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[training] fix: registration of out_channels in the control flux scripts. (#10367)
* fix: registration of out_channels in the control flux scripts. * free memory.
This commit is contained in:
@@ -795,7 +795,7 @@ def main(args):
|
||||
flux_transformer.x_embedder = new_linear
|
||||
|
||||
assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
|
||||
flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
|
||||
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
|
||||
|
||||
def unwrap_model(model):
|
||||
model = accelerator.unwrap_model(model)
|
||||
@@ -1166,6 +1166,11 @@ def main(args):
|
||||
flux_transformer.to(torch.float32)
|
||||
flux_transformer.save_pretrained(args.output_dir)
|
||||
|
||||
del flux_transformer
|
||||
del text_encoding_pipeline
|
||||
del vae
|
||||
free_memory()
|
||||
|
||||
# Run a final round of validation.
|
||||
image_logs = None
|
||||
if args.validation_prompt is not None:
|
||||
|
||||
@@ -830,7 +830,7 @@ def main(args):
|
||||
flux_transformer.x_embedder = new_linear
|
||||
|
||||
assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
|
||||
flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
|
||||
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
|
||||
|
||||
if args.train_norm_layers:
|
||||
for name, param in flux_transformer.named_parameters():
|
||||
@@ -1319,6 +1319,11 @@ def main(args):
|
||||
transformer_lora_layers=transformer_lora_layers,
|
||||
)
|
||||
|
||||
del flux_transformer
|
||||
del text_encoding_pipeline
|
||||
del vae
|
||||
free_memory()
|
||||
|
||||
# Run a final round of validation.
|
||||
image_logs = None
|
||||
if args.validation_prompt is not None:
|
||||
|
||||
Reference in New Issue
Block a user