From d0cf681a1ff3dc2c394c40d877e8fcaf7caa8080 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 8 Sep 2023 19:45:39 +0530 Subject: [PATCH] [Tests] add: tests for t2i adapter training. (#4947) add: tests for t2i adapter training. --- .../t2i_adapter/train_t2i_adapter_sdxl.py | 27 +++++++++++++------ examples/test_examples.py | 19 +++++++++++++ 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/examples/t2i_adapter/train_t2i_adapter_sdxl.py b/examples/t2i_adapter/train_t2i_adapter_sdxl.py index b09c7706a2..3d846f42f6 100644 --- a/examples/t2i_adapter/train_t2i_adapter_sdxl.py +++ b/examples/t2i_adapter/train_t2i_adapter_sdxl.py @@ -245,6 +245,13 @@ def parse_args(input_args=None): default=None, help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.", ) + parser.add_argument( + "--adapter_model_name_or_path", + type=str, + default=None, + help="Path to pretrained adapter model or model identifier from huggingface.co/models." + " If not specified adapter weights are initialized w.r.t the configurations of SDXL.", + ) parser.add_argument( "--revision", type=str, @@ -840,14 +847,18 @@ def main(args): args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) - logger.info("Initializing t2iadapter weights from unet") - t2iadapter = T2IAdapter( - in_channels=3, - channels=(320, 640, 1280, 1280), - num_res_blocks=2, - downscale_factor=16, - adapter_type="full_adapter_xl", - ) + if args.adapter_model_name_or_path: + logger.info("Loading existing adapter weights.") + t2iadapter = T2IAdapter.from_pretrained(args.adapter_model_name_or_path) + else: + logger.info("Initializing t2iadapter weights.") + t2iadapter = T2IAdapter( + in_channels=3, + channels=(320, 640, 1280, 1280), + num_res_blocks=2, + downscale_factor=16, + adapter_type="full_adapter_xl", + ) # `accelerate` 0.16.0 will have better support for customized saving if version.parse(accelerate.__version__) >= version.parse("0.16.0"): diff --git a/examples/test_examples.py b/examples/test_examples.py index 4da3cbad48..89e866231e 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -1528,6 +1528,25 @@ class ExamplesTestsAccelerate(unittest.TestCase): self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors"))) + def test_t2i_adapter_sdxl(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/t2i_adapter/train_t2i_adapter_sdxl.py + --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe + --adapter_model_name_or_path=hf-internal-testing/tiny-adapter + --dataset_name=hf-internal-testing/fill10 + --output_dir={tmpdir} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=9 + --checkpointing_steps=2 + """.split() + + run_command(self._launch_args + test_args) + + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors"))) + def test_custom_diffusion_checkpointing_checkpoints_total_limit(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f"""