diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 8e932add92..7fea4fdb64 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -761,6 +761,7 @@ def main(): num_cycles=args.lr_num_cycles, ) + text_encoder.train() # Prepare everything with our `accelerator`. text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( text_encoder, optimizer, train_dataloader, lr_scheduler