From f751b8844ebd73bdd9cfd609ea03db10e8fe0f5a Mon Sep 17 00:00:00 2001
From: Will Berman
Date: Wed, 31 May 2023 09:39:03 -0700
Subject: [PATCH] update dreambooth lora to work with IF stage II (#3560)

---
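Notes (illustrative, not part of the diff): the IF stage II UNet expects more
input channels than the 3-channel training images, since the extra channels
normally carry the low-resolution conditioning image. The training loop below
therefore pads `noisy_model_input` with random latents up to
`unet.config.in_channels`. A minimal sketch of that padding with assumed
shapes (`in_channels = 6` is illustrative, not read from a real checkpoint):

    import torch

    bsz, channels, height, width = 1, 3, 64, 64
    in_channels = 6  # stands in for unet.config.in_channels of a stage II UNet

    noisy_model_input = torch.randn(bsz, channels, height, width)
    # fill the missing channels with fresh gaussian noise, as the diff does
    additional_latents = torch.randn(bsz, in_channels - channels, height, width)
    noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1)
    assert noisy_model_input.shape[1] == in_channels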
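The stage II UNet also predicts a learned variance, returning 2 * C output
channels (noise plus variance). `scheduler.step` only accepts the extra
channels when the scheduler is configured with a learned variance type, so
each super-resolution pipeline now drops the variance half first. A sketch
with dummy tensors (all shapes illustrative):

    import torch

    C = 3
    noise_pred = torch.randn(1, 2 * C, 256, 256)       # UNet output: noise + variance
    intermediate_images = torch.randn(1, C, 256, 256)  # current sample x_t
    # keep only the noise half when variance_type is not learned/learned_range
    noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
    assert noise_pred.shape[1] == C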
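With `LoraLoaderMixin` mixed into the three pipelines, LoRA weights produced
by `train_dreambooth_lora.py` (run with `--class_labels_conditioning timesteps`,
and optionally `--validation_images`, for stage II) can be loaded at inference
time. A hedged usage sketch; the checkpoint id and LoRA path are placeholders:

    import torch
    from diffusers import IFSuperResolutionPipeline

    pipe = IFSuperResolutionPipeline.from_pretrained(
        "DeepFloyd/IF-II-L-v1.0",  # assumed stage II checkpoint id
        variant="fp16",
        torch_dtype=torch.float16,
    )
    # load_lora_weights is provided by the LoraLoaderMixin added in this patch
    pipe.load_lora_weights("path/to/lora/output_dir")  # placeholder path
    # upscaling then works as usual, e.g.:
    # image = pipe(image=stage_one_output, prompt="a sks dog").images[0]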
 examples/dreambooth/train_dreambooth_lora.py   | 49 +++++++++++++++++--
 .../pipeline_if_img2img_superresolution.py     |  6 ++-
 .../pipeline_if_inpainting_superresolution.py  |  6 ++-
 .../pipeline_if_superresolution.py             |  6 ++-
 4 files changed, 59 insertions(+), 8 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 4ff759dcd6..12b0908918 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -60,6 +60,7 @@ from diffusers.models.attention_processor import (
 from diffusers.optimization import get_scheduler
 from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import randn_tensor


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -425,6 +426,19 @@ def parse_args(input_args=None):
         required=False,
         help="Whether to use attention mask for the text encoder",
     )
+    parser.add_argument(
+        "--validation_images",
+        required=False,
+        default=None,
+        nargs="+",
+        help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
+    )
+    parser.add_argument(
+        "--class_labels_conditioning",
+        required=False,
+        default=None,
+        help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
+    )

     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1121,7 +1135,7 @@ def main(args):

                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(model_input)
-                bsz = model_input.shape[0]
+                bsz, channels, height, width = model_input.shape
                 # Sample a random timestep for each image
                 timesteps = torch.randint(
                     0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
@@ -1143,8 +1157,24 @@ def main(args):
                     text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
                 )

+                if unet.config.in_channels > channels:
+                    needed_additional_channels = unet.config.in_channels - channels
+                    additional_latents = randn_tensor(
+                        (bsz, needed_additional_channels, height, width),
+                        device=noisy_model_input.device,
+                        dtype=noisy_model_input.dtype,
+                    )
+                    noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1)
+
+                if args.class_labels_conditioning == "timesteps":
+                    class_labels = timesteps
+                else:
+                    class_labels = None
+
                 # Predict the noise residual
-                model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample
+                model_pred = unet(
+                    noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
+                ).sample

                 # if model predicts variance, throw away the prediction. we will only train on the
                 # simplified training objective. This means that all schedulers using the fine tuned
@@ -1248,9 +1278,18 @@ def main(args):
                     }
                 else:
                     pipeline_args = {"prompt": args.validation_prompt}
-                images = [
-                    pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)
-                ]
+
+                if args.validation_images is None:
+                    images = [
+                        pipeline(**pipeline_args, generator=generator).images[0]
+                        for _ in range(args.num_validation_images)
+                    ]
+                else:
+                    images = []
+                    for image in args.validation_images:
+                        image = Image.open(image)
+                        image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
+                        images.append(image)

                 for tracker in accelerator.trackers:
                     if tracker.name == "tensorboard":
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
index a49d25137b..0ee9c6ba33 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
@@ -10,6 +10,7 @@ import torch
 import torch.nn.functional as F
 from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer

+from ...loaders import LoraLoaderMixin
 from ...models import UNet2DConditionModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
@@ -112,7 +113,7 @@ EXAMPLE_DOC_STRING = """
 """


-class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline):
+class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
     tokenizer: T5Tokenizer
     text_encoder: T5EncoderModel

@@ -1047,6 +1048,9 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline):
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                     noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

+                if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
+                    noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 intermediate_images = self.scheduler.step(
                     noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
index f255948dc7..6a90f2b765 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
@@ -10,6 +10,7 @@ import torch
 import torch.nn.functional as F
 from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer

+from ...loaders import LoraLoaderMixin
 from ...models import UNet2DConditionModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
@@ -114,7 +115,7 @@ EXAMPLE_DOC_STRING = """
 """


-class IFInpaintingSuperResolutionPipeline(DiffusionPipeline):
+class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
     tokenizer: T5Tokenizer
     text_encoder: T5EncoderModel

@@ -1154,6 +1155,9 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline):
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                     noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

+                if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
+                    noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 prev_intermediate_images = intermediate_images

diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
index 7a8de51579..86d9574b97 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
@@ -10,6 +10,7 @@ import torch
 import torch.nn.functional as F
 from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer

+from ...loaders import LoraLoaderMixin
 from ...models import UNet2DConditionModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
@@ -70,7 +71,7 @@ EXAMPLE_DOC_STRING = """
 """


-class IFSuperResolutionPipeline(DiffusionPipeline):
+class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
     tokenizer: T5Tokenizer
     text_encoder: T5EncoderModel

@@ -903,6 +904,9 @@ class IFSuperResolutionPipeline(DiffusionPipeline):
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                     noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

+                if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
+                    noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 intermediate_images = self.scheduler.step(
                     noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False