From fed12376c5c6dda416f7408aec69e242f227c810 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 25 Jul 2023 05:31:48 +0530 Subject: [PATCH] [ControlNet SDXL training] fixes in the training script (#4223) * fix: #4206 * add: sdxl controlnet training smoketest. * remove unnecessary token inits. * add: licensing to model card. * include SDXL licensing in the model card and make public visibility default * debugging * debugging * disable local file download. * fix: training test. * fix: ckpt prefix. --- examples/controlnet/train_controlnet_sdxl.py | 17 ++++++++++------- examples/test_examples.py | 19 +++++++++++++++++++ src/diffusers/models/controlnet.py | 1 - 3 files changed, 29 insertions(+), 8 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index b394565058..7acb1c259b 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -124,7 +124,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step) for _ in range(args.num_validation_images): with torch.autocast("cuda"): image = pipeline( - validation_prompt, validation_image, num_inference_steps=20, generator=generator + prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator ).images[0] images.append(image) @@ -178,7 +178,7 @@ def import_model_class_from_model_name_or_path( pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" ): text_encoder_config = PretrainedConfig.from_pretrained( - pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True + pretrained_model_name_or_path, subfolder=subfolder, revision=revision ) model_class = text_encoder_config.architectures[0] @@ -226,6 +226,12 @@ inference: true These are controlnet weights trained on {base_model} with new type of conditioning. {img_str} +""" + model_card += """ + +## License + +[SDXL 0.9 Research License](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9/blob/main/LICENSE.md) """ with open(os.path.join(repo_folder, "README.md"), "w") as f: f.write(yaml + model_card) @@ -798,10 +804,7 @@ def main(args): if args.push_to_hub: repo_id = create_repo( - repo_id=args.hub_model_id or Path(args.output_dir).name, - exist_ok=True, - token=args.hub_token, - private=True, + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token ).repo_id # Load the tokenizers @@ -839,7 +842,7 @@ def main(args): revision=args.revision, ) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, use_auth_token=True + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) if args.controlnet_model_name_or_path: diff --git a/examples/test_examples.py b/examples/test_examples.py index cc3b3fbf74..1492cd7184 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -1296,6 +1296,25 @@ class ExamplesTestsAccelerate(unittest.TestCase): {"checkpoint-8", "checkpoint-10", "checkpoint-12"}, ) + def test_controlnet_sdxl(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/controlnet/train_controlnet_sdxl.py + --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe + --dataset_name=hf-internal-testing/fill10 + --output_dir={tmpdir} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl + --max_train_steps=9 + --checkpointing_steps=2 + """.split() + + run_command(self._launch_args + test_args) + + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.bin"))) + def test_custom_diffusion_checkpointing_checkpoints_total_limit(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 354cd5851d..ed3f3e6871 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -751,7 +751,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): sample = self.conv_in(sample) controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) - sample = sample + controlnet_cond # 3. down