diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 532e134a61..d1e9c28e05 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -789,10 +789,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte else: attention_mask = None - prompt_embeds = text_encoder( - text_input_ids, - attention_mask=attention_mask, - return_dict=False, + prompt_embeds = text_encoder(text_input_ids, attention_mask=attention_mask, return_dict=False) ) prompt_embeds = prompt_embeds[0]