diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 46fc692bef..e493f89125 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -52,6 +52,9 @@ from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import is_compiled_module +if is_wandb_available(): + import wandb + # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.29.0.dev0") @@ -99,6 +102,48 @@ These are LoRA adaption weights for {base_model}. The weights were fine-tuned on model_card.save(os.path.join(repo_folder, "README.md")) +def log_validation( + pipeline, + args, + accelerator, + epoch, + is_final_validation=False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + generator = torch.Generator(device=accelerator.device) + if args.seed is not None: + generator = generator.manual_seed(args.seed) + images = [] + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + with autocast_ctx: + for _ in range(args.num_validation_images): + images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + phase_name: [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + return images + + def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( @@ -414,11 +459,6 @@ def main(): if torch.backends.mps.is_available(): accelerator.native_amp = False - if args.report_to == "wandb": - if not is_wandb_available(): - raise ImportError("Make sure to install wandb if you want to use it for logging during training.") - import wandb - # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", @@ -864,10 +904,6 @@ def main(): if accelerator.is_main_process: if args.validation_prompt is not None and epoch % args.validation_epochs == 0: - logger.info( - f"Running validation... \n Generating {args.num_validation_images} images with prompt:" - f" {args.validation_prompt}." - ) # create pipeline pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, @@ -876,38 +912,7 @@ def main(): variant=args.variant, torch_dtype=weight_dtype, ) - pipeline = pipeline.to(accelerator.device) - pipeline.set_progress_bar_config(disable=True) - - # run inference - generator = torch.Generator(device=accelerator.device) - if args.seed is not None: - generator = generator.manual_seed(args.seed) - images = [] - if torch.backends.mps.is_available(): - autocast_ctx = nullcontext() - else: - autocast_ctx = torch.autocast(accelerator.device.type) - - with autocast_ctx: - for _ in range(args.num_validation_images): - images.append( - pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] - ) - - for tracker in accelerator.trackers: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "validation": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) + images = log_validation(pipeline, args, accelerator, epoch) del pipeline torch.cuda.empty_cache() @@ -925,6 +930,22 @@ def main(): safe_serialization=True, ) + # Final inference + # Load previous pipeline + if args.validation_prompt is not None: + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + images = log_validation(pipeline, args, accelerator, epoch, is_final_validation=True) + if args.push_to_hub: save_model_card( repo_id, @@ -940,51 +961,6 @@ def main(): ignore_patterns=["step_*", "epoch_*"], ) - # Final inference - # Load previous pipeline - if args.validation_prompt is not None: - pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) - pipeline = pipeline.to(accelerator.device) - - # load attention processors - pipeline.load_lora_weights(args.output_dir) - - # run inference - generator = torch.Generator(device=accelerator.device) - if args.seed is not None: - generator = generator.manual_seed(args.seed) - images = [] - if torch.backends.mps.is_available(): - autocast_ctx = nullcontext() - else: - autocast_ctx = torch.autocast(accelerator.device.type) - - with autocast_ctx: - for _ in range(args.num_validation_images): - images.append( - pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] - ) - - for tracker in accelerator.trackers: - if len(images) != 0: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "test": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) - accelerator.end_training()