mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Tests] add: tests for t2i adapter training. (#4947)
add: tests for t2i adapter training.
This commit is contained in:
@@ -245,6 +245,13 @@ def parse_args(input_args=None):
|
||||
default=None,
|
||||
help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adapter_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to pretrained adapter model or model identifier from huggingface.co/models."
|
||||
" If not specified adapter weights are initialized w.r.t the configurations of SDXL.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
@@ -840,14 +847,18 @@ def main(args):
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
||||
)
|
||||
|
||||
logger.info("Initializing t2iadapter weights from unet")
|
||||
t2iadapter = T2IAdapter(
|
||||
in_channels=3,
|
||||
channels=(320, 640, 1280, 1280),
|
||||
num_res_blocks=2,
|
||||
downscale_factor=16,
|
||||
adapter_type="full_adapter_xl",
|
||||
)
|
||||
if args.adapter_model_name_or_path:
|
||||
logger.info("Loading existing adapter weights.")
|
||||
t2iadapter = T2IAdapter.from_pretrained(args.adapter_model_name_or_path)
|
||||
else:
|
||||
logger.info("Initializing t2iadapter weights.")
|
||||
t2iadapter = T2IAdapter(
|
||||
in_channels=3,
|
||||
channels=(320, 640, 1280, 1280),
|
||||
num_res_blocks=2,
|
||||
downscale_factor=16,
|
||||
adapter_type="full_adapter_xl",
|
||||
)
|
||||
|
||||
# `accelerate` 0.16.0 will have better support for customized saving
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
|
||||
@@ -1528,6 +1528,25 @@ class ExamplesTestsAccelerate(unittest.TestCase):
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
|
||||
|
||||
def test_t2i_adapter_sdxl(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/t2i_adapter/train_t2i_adapter_sdxl.py
|
||||
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
|
||||
--adapter_model_name_or_path=hf-internal-testing/tiny-adapter
|
||||
--dataset_name=hf-internal-testing/fill10
|
||||
--output_dir={tmpdir}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=9
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
|
||||
|
||||
def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
|
||||
Reference in New Issue
Block a user