From b320c6b53e127dc9b6e2bd9cec32bad293673173 Mon Sep 17 00:00:00 2001
From: Will Berman
Date: Wed, 19 Apr 2023 10:46:51 -0700
Subject: [PATCH] controlnet training resize inputs to multiple of 8 (#3135)

controlnet training center crop input images to multiple of 8

The pipeline code resizes inputs to multiples of 8. Not doing this resizing in the training script is causing the encoded image to have different height/width dimensions than the encoded conditioning image (which uses a separate encoder that's part of the controlnet model).

We resize and center crop the inputs to make sure they're the same size (as well as all other images in the batch).

We also check that the initial resolution is a multiple of 8.
---
 examples/controlnet/train_controlnet.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py
index c0b52291fc..d52e610ca5 100644
--- a/examples/controlnet/train_controlnet.py
+++ b/examples/controlnet/train_controlnet.py
@@ -525,6 +525,11 @@ def parse_args(input_args=None):
             " or the same number of `--validation_prompt`s and `--validation_image`s"
         )
 
+    if args.resolution % 8 != 0:
+        raise ValueError(
+            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+        )
+
     return args
 
 
@@ -607,6 +612,7 @@ def make_train_dataset(args, tokenizer, accelerator):
     image_transforms = transforms.Compose(
         [
             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(args.resolution),
             transforms.ToTensor(),
             transforms.Normalize([0.5], [0.5]),
         ]
@@ -615,6 +621,7 @@ def make_train_dataset(args, tokenizer, accelerator):
     conditioning_image_transforms = transforms.Compose(
         [
             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(args.resolution),
             transforms.ToTensor(),
         ]
     )