1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Tests] add: tests for t2i adapter training. (#4947)

add: tests for t2i adapter training.
This commit is contained in:
Sayak Paul
2023-09-08 19:45:39 +05:30
committed by GitHub
parent dfec61f4b3
commit d0cf681a1f
2 changed files with 38 additions and 8 deletions

View File

@@ -245,6 +245,13 @@ def parse_args(input_args=None):
default=None,
help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
)
parser.add_argument(
"--adapter_model_name_or_path",
type=str,
default=None,
help="Path to pretrained adapter model or model identifier from huggingface.co/models."
" If not specified adapter weights are initialized w.r.t the configurations of SDXL.",
)
parser.add_argument(
"--revision",
type=str,
@@ -840,14 +847,18 @@ def main(args):
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
logger.info("Initializing t2iadapter weights from unet")
t2iadapter = T2IAdapter(
in_channels=3,
channels=(320, 640, 1280, 1280),
num_res_blocks=2,
downscale_factor=16,
adapter_type="full_adapter_xl",
)
if args.adapter_model_name_or_path:
logger.info("Loading existing adapter weights.")
t2iadapter = T2IAdapter.from_pretrained(args.adapter_model_name_or_path)
else:
logger.info("Initializing t2iadapter weights.")
t2iadapter = T2IAdapter(
in_channels=3,
channels=(320, 640, 1280, 1280),
num_res_blocks=2,
downscale_factor=16,
adapter_type="full_adapter_xl",
)
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):

View File

@@ -1528,6 +1528,25 @@ class ExamplesTestsAccelerate(unittest.TestCase):
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
def test_t2i_adapter_sdxl(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/t2i_adapter/train_t2i_adapter_sdxl.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
--adapter_model_name_or_path=hf-internal-testing/tiny-adapter
--dataset_name=hf-internal-testing/fill10
--output_dir={tmpdir}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=9
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""