diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index ad862c6976..8b46d2143b 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -999,7 +999,7 @@ def main(args): validation_prompt_encoder_hidden_states = None if args.class_prompt is not None: - pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) + pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt) else: pre_computed_class_prompt_encoder_hidden_states = None