From 4f14b363297cf8deac3e88a3bf31f59880ac8a96 Mon Sep 17 00:00:00 2001 From: Will Berman Date: Wed, 31 May 2023 09:39:31 -0700 Subject: [PATCH] Full Dreambooth IF stage II upscaling (#3561) * update dreambooth lora to work with IF stage II * Update dreambooth script for IF stage II upscaler --- examples/dreambooth/train_dreambooth.py | 55 +++++++++++++++++++++---- 1 file changed, 46 insertions(+), 9 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 37b06acb69..e4ab6b2ae0 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -52,6 +52,7 @@ from diffusers import ( from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import randn_tensor if is_wandb_available(): @@ -114,16 +115,17 @@ def log_validation( pipeline_args = {} - if text_encoder is not None: - pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder) - if vae is not None: pipeline_args["vae"] = vae + if text_encoder is not None: + text_encoder = accelerator.unwrap_model(text_encoder) + # create pipeline (note: unet and vae are loaded again in float32) pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, tokenizer=tokenizer, + text_encoder=text_encoder, unet=accelerator.unwrap_model(unet), revision=args.revision, torch_dtype=weight_dtype, @@ -156,10 +158,16 @@ def log_validation( # run inference generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) images = [] - for _ in range(args.num_validation_images): - with torch.autocast("cuda"): - image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0] - images.append(image) + if args.validation_images is None: + for _ in range(args.num_validation_images): + with torch.autocast("cuda"): + image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0] + images.append(image) + else: + for image in args.validation_images: + image = Image.open(image) + image = pipeline(**pipeline_args, image=image, generator=generator).images[0] + images.append(image) for tracker in accelerator.trackers: if tracker.name == "tensorboard": @@ -525,6 +533,19 @@ def parse_args(input_args=None): parser.add_argument( "--skip_save_text_encoder", action="store_true", required=False, help="Set to not save text encoder" ) + parser.add_argument( + "--validation_images", + required=False, + default=None, + nargs="+", + help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.", + ) + parser.add_argument( + "--class_labels_conditioning", + required=False, + default=None, + help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.", + ) if input_args is not None: args = parser.parse_args(input_args) @@ -1169,7 +1190,7 @@ def main(args): ) else: noise = torch.randn_like(model_input) - bsz = model_input.shape[0] + bsz, channels, height, width = model_input.shape # Sample a random timestep for each image timesteps = torch.randint( 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device @@ -1191,8 +1212,24 @@ def main(args): text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, ) + if unet.config.in_channels > channels: + needed_additional_channels = unet.config.in_channels - channels + additional_latents = randn_tensor( + (bsz, needed_additional_channels, height, width), + device=noisy_model_input.device, + dtype=noisy_model_input.dtype, + ) + noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1) + + if args.class_labels_conditioning == "timesteps": + class_labels = timesteps + else: + class_labels = None + # Predict the noise residual - model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample + model_pred = unet( + noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels + ).sample if model_pred.shape[1] == 6: model_pred, _ = torch.chunk(model_pred, 2, dim=1)