
[advanced flux training] bug fix + reduce memory cost as in #9829 (#9838)

* memory improvement as done here: https://github.com/huggingface/diffusers/pull/9829

* fix bug

* fix bug

* style

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Author: Linoy Tsaban
Date: 2024-11-18 23:13:36 -04:00
Committed by: GitHub
Parent: 03bf77c4af
Commit: acf479bded

2 changed files with 17 additions and 3 deletions

File 1 of 2:

@@ -2154,6 +2154,7 @@ def main(args):
 
                 # encode batch prompts when custom prompts are provided for each image -
                 if train_dataset.custom_instance_prompts:
+                    elems_to_repeat = 1
                     if freeze_text_encoder:
                         prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
                             prompts, text_encoders, tokenizers
@@ -2168,17 +2169,21 @@ def main(args):
                             max_sequence_length=args.max_sequence_length,
                             add_special_tokens=add_special_tokens_t5,
                         )
-
+                else:
+                    elems_to_repeat = len(prompts)
                 if not freeze_text_encoder:
                     prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
                         text_encoders=[text_encoder_one, text_encoder_two],
                         tokenizers=[None, None],
-                        text_input_ids_list=[tokens_one, tokens_two],
+                        text_input_ids_list=[
+                            tokens_one.repeat(elems_to_repeat, 1),
+                            tokens_two.repeat(elems_to_repeat, 1),
+                        ],
                         max_sequence_length=args.max_sequence_length,
                         device=accelerator.device,
                         prompt=prompts,
                     )
 
                 # Convert images to latent space
                 if args.cache_latents:
                     model_input = latents_cache[step].sample()
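
The bug being fixed: when training without per-image custom prompts, tokens_one and tokens_two are tokenized once from the single instance prompt and carry a batch dimension of 1, while each step processes len(prompts) images. Repeating the ids before encode_prompt makes the embedding batch match the latent batch. A minimal sketch of the shape arithmetic (hypothetical sizes and vocab bounds, not code from this commit):

    import torch

    prompts = ["a photo of sks dog"] * 4            # batch assembled by the dataloader
    tokens_one = torch.randint(0, 49408, (1, 77))   # CLIP token ids, shape [1, 77]
    tokens_two = torch.randint(0, 32128, (1, 512))  # T5 token ids, shape [1, 512]

    elems_to_repeat = len(prompts)
    tokens_one = tokens_one.repeat(elems_to_repeat, 1)  # tile along batch dim -> [4, 77]
    tokens_two = tokens_two.repeat(elems_to_repeat, 1)  # tile along batch dim -> [4, 512]
    print(tokens_one.shape, tokens_two.shape)  # torch.Size([4, 77]) torch.Size([4, 512])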
@@ -2371,6 +2376,9 @@ def main(args):
                     epoch=epoch,
                     torch_dtype=weight_dtype,
                 )
+                images = None
+                del pipeline
+
                 if freeze_text_encoder:
                     del text_encoder_one, text_encoder_two
                     free_memory()
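
The memory saving mirrors #9829: the validation images and pipeline are dereferenced before free_memory() runs, so the pipeline's weights can actually be reclaimed. A rough standalone approximation of that cleanup, assuming a CUDA device (the stand-in objects below are placeholders, not the real pipeline):

    import gc
    import torch

    pipeline = object()  # stand-in for the validation pipeline holding model weights
    images = [object()]  # stand-in for the generated validation images

    images = None  # drop the reference to the generated images
    del pipeline   # drop the last reference to the pipeline
    gc.collect()   # let Python reclaim the now-unreferenced objects
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached CUDA blocks to the allocator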
@@ -2448,6 +2456,8 @@ def main(args):
             commit_message="End of training",
             ignore_patterns=["step_*", "epoch_*"],
         )
 
+        images = None
+        del pipeline
 
     accelerator.end_training()

File 2 of 2:

@@ -1648,11 +1648,15 @@ def main(args):
                             prompt=prompts,
                         )
                 else:
+                    elems_to_repeat = len(prompts)
                     if args.train_text_encoder:
                         prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
                             text_encoders=[text_encoder_one, text_encoder_two],
                             tokenizers=[None, None],
-                            text_input_ids_list=[tokens_one, tokens_two],
+                            text_input_ids_list=[
+                                tokens_one.repeat(elems_to_repeat, 1),
+                                tokens_two.repeat(elems_to_repeat, 1),
+                            ],
                             max_sequence_length=args.max_sequence_length,
                             device=accelerator.device,
                             prompt=args.instance_prompt,
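
The same fix lands in the second script, where the token ids always come from the single args.instance_prompt. A toy sketch of why one elems_to_repeat value covers both code paths (toy_encode is a hypothetical stand-in, not the diffusers API):

    import torch

    def toy_encode(token_ids: torch.Tensor, batch_size: int) -> torch.Tensor:
        # stand-in for encode_prompt: the token batch must match the image batch
        assert token_ids.shape[0] == batch_size
        return torch.randn(batch_size, token_ids.shape[1], 8)

    prompts = ["a sks dog"] * 4

    # custom per-image prompts: ids already have batch dim 4, so repeat once (a no-op)
    custom_ids = torch.randint(0, 49408, (len(prompts), 77))
    toy_encode(custom_ids.repeat(1, 1), batch_size=len(prompts))

    # single instance prompt: ids have batch dim 1, repeat len(prompts) times
    single_ids = torch.randint(0, 49408, (1, 77))
    toy_encode(single_ids.repeat(len(prompts), 1), batch_size=len(prompts))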