diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index e160e00c70..498269b05f 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -359,7 +359,7 @@ def main():
             sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
         ):
             with torch.no_grad():
-                images = sd_model(example["prompt"], height=512, width=512, num_inference_steps=50)["sample"]
+                images = sd_model(example["prompt"], height=512, width=512, num_inference_steps=50).images

             for image, index in zip(images, example["index"]):
                 image.save(class_images_dir / f"{index + cur_class_images}.jpg")
@@ -450,9 +450,6 @@ def main():
     text_encoder.to(accelerator.device)
     vae.to(accelerator.device)

-    # Keep text_encoder and vae in eval model as we don't train it
-    text_encoder.eval()
-    vae.eval()

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
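
For context on the first hunk: the pipeline call now returns a `StableDiffusionPipelineOutput` object instead of a plain dict, so generated images are read from the `.images` attribute rather than indexed with `["sample"]`. Below is a minimal sketch of that access pattern, assuming `CompVis/stable-diffusion-v1-4` as a placeholder checkpoint (the training script itself takes the model id from `--pretrained_model_name_or_path`):

```python
# Minimal sketch of the `.images` access pattern used in the hunk above.
# "CompVis/stable-diffusion-v1-4" is a placeholder checkpoint, not part of the patch.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

with torch.no_grad():
    # The call returns a StableDiffusionPipelineOutput, not a dict,
    # so the generated PIL images live on the `.images` attribute.
    output = pipe("a photo of a dog", height=512, width=512, num_inference_steps=50)

output.images[0].save("sample.jpg")
```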