mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
* memory improvement as done here: https://github.com/huggingface/diffusers/pull/9829 * fix bug * fix bug * style --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -2154,6 +2154,7 @@ def main(args):
|
||||
|
||||
# encode batch prompts when custom prompts are provided for each image -
|
||||
if train_dataset.custom_instance_prompts:
|
||||
elems_to_repeat = 1
|
||||
if freeze_text_encoder:
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
|
||||
prompts, text_encoders, tokenizers
|
||||
@@ -2168,17 +2169,21 @@ def main(args):
|
||||
max_sequence_length=args.max_sequence_length,
|
||||
add_special_tokens=add_special_tokens_t5,
|
||||
)
|
||||
else:
|
||||
elems_to_repeat = len(prompts)
|
||||
|
||||
if not freeze_text_encoder:
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
|
||||
text_encoders=[text_encoder_one, text_encoder_two],
|
||||
tokenizers=[None, None],
|
||||
text_input_ids_list=[tokens_one, tokens_two],
|
||||
text_input_ids_list=[
|
||||
tokens_one.repeat(elems_to_repeat, 1),
|
||||
tokens_two.repeat(elems_to_repeat, 1),
|
||||
],
|
||||
max_sequence_length=args.max_sequence_length,
|
||||
device=accelerator.device,
|
||||
prompt=prompts,
|
||||
)
|
||||
|
||||
# Convert images to latent space
|
||||
if args.cache_latents:
|
||||
model_input = latents_cache[step].sample()
|
||||
@@ -2371,6 +2376,9 @@ def main(args):
|
||||
epoch=epoch,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
images = None
|
||||
del pipeline
|
||||
|
||||
if freeze_text_encoder:
|
||||
del text_encoder_one, text_encoder_two
|
||||
free_memory()
|
||||
@@ -2448,6 +2456,8 @@ def main(args):
|
||||
commit_message="End of training",
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
images = None
|
||||
del pipeline
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
@@ -1648,11 +1648,15 @@ def main(args):
|
||||
prompt=prompts,
|
||||
)
|
||||
else:
|
||||
elems_to_repeat = len(prompts)
|
||||
if args.train_text_encoder:
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
|
||||
text_encoders=[text_encoder_one, text_encoder_two],
|
||||
tokenizers=[None, None],
|
||||
text_input_ids_list=[tokens_one, tokens_two],
|
||||
text_input_ids_list=[
|
||||
tokens_one.repeat(elems_to_repeat, 1),
|
||||
tokens_two.repeat(elems_to_repeat, 1),
|
||||
],
|
||||
max_sequence_length=args.max_sequence_length,
|
||||
device=accelerator.device,
|
||||
prompt=args.instance_prompt,
|
||||
|
||||
Reference in New Issue
Block a user