From fdd003d8e2287e4dcdfc14a7d2fc9444e55c1904 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 29 Nov 2023 18:43:59 +0530
Subject: [PATCH] [Tests] Refactor `test_examples.py` for better readability
 (#5946)

* control and custom diffusion

* dreambooth

* instructpix2pix and dreambooth ckpting

* t2i adapters.

* text to image ft

* textual inversion

* unconditional

* workflows

* import fix

* fix import
---
 .github/workflows/pr_tests.yml                |    2 +-
 .github/workflows/push_tests_fast.yml         |    2 +-
 examples/controlnet/test_controlnet.py        |  120 ++
 .../custom_diffusion/test_custom_diffusion.py |  130 ++
 examples/dreambooth/test_dreambooth.py        |  230 +++
 examples/dreambooth/test_dreambooth_lora.py   |  388 ++++
 .../instruct_pix2pix/test_instruct_pix2pix.py |  101 +
 examples/t2i_adapter/test_t2i_adapter.py      |   51 +
 examples/test_examples.py                     | 1725 -----------------
 examples/test_examples_utils.py               |   61 +
 examples/text_to_image/test_text_to_image.py  |  373 ++++
 .../text_to_image/test_text_to_image_lora.py  |  308 +++
 .../test_textual_inversion.py                 |  160 ++
 .../test_unconditional.py                     |  130 ++
 14 files changed, 2054 insertions(+), 1727 deletions(-)
 create mode 100644 examples/controlnet/test_controlnet.py
 create mode 100644 examples/custom_diffusion/test_custom_diffusion.py
 create mode 100644 examples/dreambooth/test_dreambooth.py
 create mode 100644 examples/dreambooth/test_dreambooth_lora.py
 create mode 100644 examples/instruct_pix2pix/test_instruct_pix2pix.py
 create mode 100644 examples/t2i_adapter/test_t2i_adapter.py
 delete mode 100644 examples/test_examples.py
 create mode 100644 examples/test_examples_utils.py
 create mode 100644 examples/text_to_image/test_text_to_image.py
 create mode 100644 examples/text_to_image/test_text_to_image_lora.py
 create mode 100644 examples/textual_inversion/test_textual_inversion.py
 create mode 100644 examples/unconditional_image_generation/test_unconditional.py

diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml
index f7d9dde525..ff687db5c1 100644
--- a/.github/workflows/pr_tests.yml
+++ b/.github/workflows/pr_tests.yml
@@ -115,7 +115,7 @@ jobs:
         run: |
           python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
             --make-reports=tests_${{ matrix.config.report }} \
-            examples/test_examples.py
+            examples

       - name: Failure short reports
         if: ${{ failure() }}
diff --git a/.github/workflows/push_tests_fast.yml b/.github/workflows/push_tests_fast.yml
index 798fa777c6..6ea873d0a7 100644
--- a/.github/workflows/push_tests_fast.yml
+++ b/.github/workflows/push_tests_fast.yml
@@ -100,7 +100,7 @@ jobs:
         run: |
           python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
             --make-reports=tests_${{ matrix.config.report }} \
-            examples/test_examples.py
+            examples

       - name: Failure short reports
         if: ${{ failure() }}
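Both workflows previously pointed pytest at the single examples/test_examples.py module; after the split they collect every test module under examples/. A minimal local equivalent of the CI invocation above, assuming pytest and pytest-xdist are installed and the command is run from the repository root:

    python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile examples

A side benefit of the per-pipeline layout is that a single pipeline's tests can now be selected by directory, e.g. `python -m pytest examples/controlnet`.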
diff --git a/examples/controlnet/test_controlnet.py b/examples/controlnet/test_controlnet.py
new file mode 100644
index 0000000000..e62d095ada
--- /dev/null
+++ b/examples/controlnet/test_controlnet.py
@@ -0,0 +1,120 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class ControlNet(ExamplesTestsAccelerate):
+    def test_controlnet_checkpointing_checkpoints_total_limit(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/controlnet/train_controlnet.py
+                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+                --dataset_name=hf-internal-testing/fill10
+                --output_dir={tmpdir}
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=6
+                --checkpoints_total_limit=2
+                --checkpointing_steps=2
+                --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-4", "checkpoint-6"},
+            )
+
+    def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/controlnet/train_controlnet.py
+                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+                --dataset_name=hf-internal-testing/fill10
+                --output_dir={tmpdir}
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
+                --max_train_steps=9
+                --checkpointing_steps=2
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+            )
+
+            resume_run_args = f"""
+                examples/controlnet/train_controlnet.py
+                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+                --dataset_name=hf-internal-testing/fill10
+                --output_dir={tmpdir}
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
+                --max_train_steps=11
+                --checkpointing_steps=2
+                --resume_from_checkpoint=checkpoint-8
+                --checkpoints_total_limit=3
+                """.split()
+
+            run_command(self._launch_args + resume_run_args)
+
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-8", "checkpoint-10", "checkpoint-12"},
+            )
+
+
+class ControlNetSDXL(ExamplesTestsAccelerate):
+    def test_controlnet_sdxl(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/controlnet/train_controlnet_sdxl.py
+                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
+                --dataset_name=hf-internal-testing/fill10
+                --output_dir={tmpdir}
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl
+                --max_train_steps=9
+                --checkpointing_steps=2
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
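The assertions above encode the rotation rule behind --checkpoints_total_limit: a checkpoint directory is written every --checkpointing_steps steps, and once the limit is exceeded the oldest directories are pruned. A minimal sketch of that arithmetic for a single run (the helper name is made up for illustration):

    # Which checkpoint-N directories should survive a single training run?
    def surviving_checkpoints(max_train_steps, checkpointing_steps, total_limit=None):
        steps = list(range(checkpointing_steps, max_train_steps + 1, checkpointing_steps))
        kept = steps if total_limit is None else steps[-total_limit:]
        return {f"checkpoint-{s}" for s in kept}

    assert surviving_checkpoints(6, 2, total_limit=2) == {"checkpoint-4", "checkpoint-6"}
    assert surviving_checkpoints(9, 2) == {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}

On resume, directories carried over from the first run count toward the limit as well, so after the second run only the most recent --checkpoints_total_limit directories remain.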
diff --git a/examples/custom_diffusion/test_custom_diffusion.py b/examples/custom_diffusion/test_custom_diffusion.py
new file mode 100644
index 0000000000..78f24c5172
--- /dev/null
+++ b/examples/custom_diffusion/test_custom_diffusion.py
@@ -0,0 +1,130 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class CustomDiffusion(ExamplesTestsAccelerate):
+    def test_custom_diffusion(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/custom_diffusion/train_custom_diffusion.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt <new1>
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 1.0e-05
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --modifier_token <new1>
+                --no_safe_serialization
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_custom_diffusion_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "<new1>.bin")))
+
+    def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/custom_diffusion/train_custom_diffusion.py
+                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+                --instance_data_dir=docs/source/en/imgs
+                --output_dir={tmpdir}
+                --instance_prompt=<new1>
+                --resolution=64
+                --train_batch_size=1
+                --modifier_token=<new1>
+                --dataloader_num_workers=0
+                --max_train_steps=6
+                --checkpoints_total_limit=2
+                --checkpointing_steps=2
+                --no_safe_serialization
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-4", "checkpoint-6"},
+            )
+
+    def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/custom_diffusion/train_custom_diffusion.py
+                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+                --instance_data_dir=docs/source/en/imgs
+                --output_dir={tmpdir}
+                --instance_prompt=<new1>
+                --resolution=64
+                --train_batch_size=1
+                --modifier_token=<new1>
+                --dataloader_num_workers=0
+                --max_train_steps=9
+                --checkpointing_steps=2
+                --no_safe_serialization
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+            )
+
+            resume_run_args = f"""
+                examples/custom_diffusion/train_custom_diffusion.py
+                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+                --instance_data_dir=docs/source/en/imgs
+                --output_dir={tmpdir}
+                --instance_prompt=<new1>
+                --resolution=64
+                --train_batch_size=1
+                --modifier_token=<new1>
+                --dataloader_num_workers=0
+                --max_train_steps=11
+                --checkpointing_steps=2
+                --resume_from_checkpoint=checkpoint-8
+                --checkpoints_total_limit=3
+                --no_safe_serialization
+                """.split()
+
+            run_command(self._launch_args + resume_run_args)
+
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
+            )
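The two files asserted at the end of test_custom_diffusion are the cross-attention weights and the learned <new1> token embedding. A minimal consumption sketch, following the loading pattern from the Custom Diffusion docs (the output path and prompt are illustrative):

    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe")
    # attach the trained attention processors and the learned token embedding
    pipe.unet.load_attn_procs("path/to/output_dir", weight_name="pytorch_custom_diffusion_weights.bin")
    pipe.load_textual_inversion("path/to/output_dir", weight_name="<new1>.bin")
    image = pipe("<new1> photo", num_inference_steps=2).images[0]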
diff --git a/examples/dreambooth/test_dreambooth.py b/examples/dreambooth/test_dreambooth.py
new file mode 100644
index 0000000000..0c6c2a0623
--- /dev/null
+++ b/examples/dreambooth/test_dreambooth.py
@@ -0,0 +1,230 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import shutil
+import sys
+import tempfile
+
+from diffusers import DiffusionPipeline, UNet2DConditionModel
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBooth(ExamplesTestsAccelerate):
+    def test_dreambooth(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+    def test_dreambooth_if(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --pre_compute_text_embeddings
+                --tokenizer_max_length=77
+                --text_encoder_use_attention_mask
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+    def test_dreambooth_checkpointing(self):
+        instance_prompt = "photo"
+        pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Run training script with checkpointing
+            # max_train_steps == 5, checkpointing_steps == 2
+            # Should create checkpoints at steps 2, 4
+
+            initial_run_args = f"""
+                examples/dreambooth/train_dreambooth.py
+                --pretrained_model_name_or_path {pretrained_model_name_or_path}
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt {instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 5
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --checkpointing_steps=2
+                --seed=0
+                """.split()
+
+            run_command(self._launch_args + initial_run_args)
+
+            # check can run the original fully trained output pipeline
+            pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+            pipe(instance_prompt, num_inference_steps=2)
+
+            # check checkpoint directories exist
+            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
+            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
+
+            # check can run an intermediate checkpoint
+            unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
+            pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
+            pipe(instance_prompt, num_inference_steps=2)
+
+            # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
+            shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
+
+            # Run training script for 7 total steps resuming from checkpoint 4
+
+            resume_run_args = f"""
+                examples/dreambooth/train_dreambooth.py
+                --pretrained_model_name_or_path {pretrained_model_name_or_path}
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt {instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 7
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --checkpointing_steps=2
+                --resume_from_checkpoint=checkpoint-4
+                --seed=0
+                """.split()
+
+            run_command(self._launch_args + resume_run_args)
+
+            # check can run new fully trained pipeline
+            pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+            pipe(instance_prompt, num_inference_steps=2)
+
+            # check old checkpoints do not exist
+            self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
+
+            # check new checkpoints exist
+            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
+            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
+
+    def test_dreambooth_checkpointing_checkpoints_total_limit(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth.py
+                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+                --instance_data_dir=docs/source/en/imgs
+                --output_dir={tmpdir}
+                --instance_prompt=prompt
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=6
+                --checkpoints_total_limit=2
+                --checkpointing_steps=2
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-4", "checkpoint-6"},
+            )
+
+    def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth.py
+                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+                --instance_data_dir=docs/source/en/imgs
+                --output_dir={tmpdir}
+                --instance_prompt=prompt
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=9
+                --checkpointing_steps=2
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+            )
+
+            resume_run_args = f"""
+                examples/dreambooth/train_dreambooth.py
+                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+                --instance_data_dir=docs/source/en/imgs
+                --output_dir={tmpdir}
+                --instance_prompt=prompt
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=11
+                --checkpointing_steps=2
+                --resume_from_checkpoint=checkpoint-8
+                --checkpoints_total_limit=3
+                """.split()
+
+            run_command(self._launch_args + resume_run_args)
+
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
+            )
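test_dreambooth_checkpointing also documents the user-facing recipe for sampling from an intermediate checkpoint: load the saved UNet subfolder and graft it onto the base pipeline. The same pattern works outside the test; a minimal sketch, with the output directory and checkpoint step as illustrative placeholders:

    from diffusers import DiffusionPipeline, UNet2DConditionModel

    # load the UNet saved at step 4 and pair it with the original base pipeline
    unet = UNet2DConditionModel.from_pretrained("path/to/output_dir", subfolder="checkpoint-4/unet")
    pipe = DiffusionPipeline.from_pretrained(
        "hf-internal-testing/tiny-stable-diffusion-pipe", unet=unet, safety_checker=None
    )
    pipe("photo", num_inference_steps=2)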
diff --git a/examples/dreambooth/test_dreambooth_lora.py b/examples/dreambooth/test_dreambooth_lora.py
new file mode 100644
index 0000000000..fc43269f73
--- /dev/null
+++ b/examples/dreambooth/test_dreambooth_lora.py
@@ -0,0 +1,388 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+import safetensors
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
+
+from diffusers import DiffusionPipeline  # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothLoRA(ExamplesTestsAccelerate):
+    def test_dreambooth_lora(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth_lora.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when not training the text encoder, all the parameters in the state dict should start
+            # with `"unet"` in their names.
+            starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+            self.assertTrue(starts_with_unet)
+
+    def test_dreambooth_lora_with_text_encoder(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth_lora.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --train_text_encoder
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            # check `text_encoder` is present at all.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            keys = lora_state_dict.keys()
+            is_text_encoder_present = any(k.startswith("text_encoder") for k in keys)
+            self.assertTrue(is_text_encoder_present)
+
+            # the names of the keys of the state dict should either start with `unet`
+            # or `text_encoder`.
+            is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
+            self.assertTrue(is_correct_naming)
+
+    def test_dreambooth_lora_checkpointing_checkpoints_total_limit(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth_lora.py
+                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+                --instance_data_dir=docs/source/en/imgs
+                --output_dir={tmpdir}
+                --instance_prompt=prompt
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=6
+                --checkpoints_total_limit=2
+                --checkpointing_steps=2
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-4", "checkpoint-6"},
+            )
+
+    def test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth_lora.py
+                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+                --instance_data_dir=docs/source/en/imgs
+                --output_dir={tmpdir}
+                --instance_prompt=prompt
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=9
+                --checkpointing_steps=2
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+            )
+
+            resume_run_args = f"""
+                examples/dreambooth/train_dreambooth_lora.py
+                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+                --instance_data_dir=docs/source/en/imgs
+                --output_dir={tmpdir}
+                --instance_prompt=prompt
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=11
+                --checkpointing_steps=2
+                --resume_from_checkpoint=checkpoint-8
+                --checkpoints_total_limit=3
+                """.split()
+
+            run_command(self._launch_args + resume_run_args)
+
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
+            )
+
+    def test_dreambooth_lora_if_model(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth_lora.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --pre_compute_text_embeddings
+                --tokenizer_max_length=77
+                --text_encoder_use_attention_mask
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when not training the text encoder, all the parameters in the state dict should start
+            # with `"unet"` in their names.
+            starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+            self.assertTrue(starts_with_unet)
+
+
+class DreamBoothLoRASDXL(ExamplesTestsAccelerate):
+    def test_dreambooth_lora_sdxl(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth_lora_sdxl.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when not training the text encoder, all the parameters in the state dict should start
+            # with `"unet"` in their names.
+            starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+            self.assertTrue(starts_with_unet)
+
+    def test_dreambooth_lora_sdxl_with_text_encoder(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth_lora_sdxl.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --train_text_encoder
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when training the text encoder, all the parameters in the state dict should start
+            # with `"unet"`, `"text_encoder"`, or `"text_encoder_2"` in their names.
+            keys = lora_state_dict.keys()
+            starts_with_unet = all(
+                k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys
+            )
+            self.assertTrue(starts_with_unet)
+
+    def test_dreambooth_lora_sdxl_custom_captions(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth_lora_sdxl.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+                --dataset_name hf-internal-testing/dummy_image_text_data
+                --caption_column text
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+    def test_dreambooth_lora_sdxl_text_encoder_custom_captions(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth_lora_sdxl.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+                --dataset_name hf-internal-testing/dummy_image_text_data
+                --caption_column text
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --train_text_encoder
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+    def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):
+        pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth_lora_sdxl.py
+                --pretrained_model_name_or_path {pipeline_path}
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 7
+                --checkpointing_steps=2
+                --checkpoints_total_limit=2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+            pipe.load_lora_weights(tmpdir)
+            pipe("a prompt", num_inference_steps=2)
+
+            # check checkpoint directories exist
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                # checkpoint-2 should have been deleted
+                {"checkpoint-4", "checkpoint-6"},
+            )
+
+    def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
+        pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth_lora_sdxl.py
+                --pretrained_model_name_or_path {pipeline_path}
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 7
+                --checkpointing_steps=2
+                --checkpoints_total_limit=2
+                --train_text_encoder
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+            pipe.load_lora_weights(tmpdir)
+            pipe("a prompt", num_inference_steps=2)
+
+            # check checkpoint directories exist
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                # checkpoint-2 should have been deleted
+                {"checkpoint-4", "checkpoint-6"},
+            )
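Several of the tests above repeat the same key-prefix audit on pytorch_lora_weights.safetensors. The check distills to a few lines; a sketch with an illustrative helper name (note that str.startswith accepts a tuple of prefixes):

    import safetensors.torch

    def audit_lora_keys(path, allowed_prefixes=("unet",)):
        state_dict = safetensors.torch.load_file(path)
        # every trained parameter should be a LoRA weight ...
        assert all("lora" in k for k in state_dict), "found a non-LoRA parameter"
        # ... attached to one of the expected submodules
        assert all(k.startswith(allowed_prefixes) for k in state_dict), "unexpected module prefix"

    # unet-only runs:
    #     audit_lora_keys("pytorch_lora_weights.safetensors")
    # SDXL runs with --train_text_encoder:
    #     audit_lora_keys("pytorch_lora_weights.safetensors",
    #                     allowed_prefixes=("unet", "text_encoder", "text_encoder_2"))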
diff --git a/examples/instruct_pix2pix/test_instruct_pix2pix.py b/examples/instruct_pix2pix/test_instruct_pix2pix.py
new file mode 100644
index 0000000000..c4d7500723
--- /dev/null
+++ b/examples/instruct_pix2pix/test_instruct_pix2pix.py
@@ -0,0 +1,101 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class InstructPix2Pix(ExamplesTestsAccelerate):
+    def test_instruct_pix2pix_checkpointing_checkpoints_total_limit(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/instruct_pix2pix/train_instruct_pix2pix.py
+                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+                --dataset_name=hf-internal-testing/instructpix2pix-10-samples
+                --resolution=64
+                --random_flip
+                --train_batch_size=1
+                --max_train_steps=7
+                --checkpointing_steps=2
+                --checkpoints_total_limit=2
+                --output_dir {tmpdir}
+                --seed=0
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-4", "checkpoint-6"},
+            )
+
+    def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/instruct_pix2pix/train_instruct_pix2pix.py
+                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+                --dataset_name=hf-internal-testing/instructpix2pix-10-samples
+                --resolution=64
+                --random_flip
+                --train_batch_size=1
+                --max_train_steps=9
+                --checkpointing_steps=2
+                --output_dir {tmpdir}
+                --seed=0
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            # check checkpoint directories exist
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+            )
+
+            resume_run_args = f"""
+                examples/instruct_pix2pix/train_instruct_pix2pix.py
+                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+                --dataset_name=hf-internal-testing/instructpix2pix-10-samples
+                --resolution=64
+                --random_flip
+                --train_batch_size=1
+                --max_train_steps=11
+                --checkpointing_steps=2
+                --output_dir {tmpdir}
+                --seed=0
+                --resume_from_checkpoint=checkpoint-8
+                --checkpoints_total_limit=3
+                """.split()
+
+            run_command(self._launch_args + resume_run_args)
+
+            # check checkpoint directories exist
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
+            )
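The resumption case asserts {checkpoint-6, checkpoint-8, checkpoint-10}: the run restarts from step 8, writes checkpoint-10, and with --checkpoints_total_limit=3 the oldest carried-over directories (2 and 4) are pruned when the new checkpoint lands. A toy model of that behavior, extending the single-run sketch from earlier (the helper is illustrative, not the trainer's actual code):

    def surviving_after_resume(carried_over, new_steps, total_limit):
        # checkpoints present once the resumed run finishes; oldest pruned first
        all_ckpts = sorted(set(carried_over) | set(new_steps))
        return {f"checkpoint-{s}" for s in all_ckpts[-total_limit:]}

    assert surviving_after_resume([2, 4, 6, 8], [10], 3) == {"checkpoint-6", "checkpoint-8", "checkpoint-10"}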
diff --git a/examples/t2i_adapter/test_t2i_adapter.py b/examples/t2i_adapter/test_t2i_adapter.py
new file mode 100644
index 0000000000..fe8fd9d8c2
--- /dev/null
+++ b/examples/t2i_adapter/test_t2i_adapter.py
@@ -0,0 +1,51 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class T2IAdapter(ExamplesTestsAccelerate):
+    def test_t2i_adapter_sdxl(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/t2i_adapter/train_t2i_adapter_sdxl.py
+                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
+                --adapter_model_name_or_path=hf-internal-testing/tiny-adapter
+                --dataset_name=hf-internal-testing/fill10
+                --output_dir={tmpdir}
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=9
+                --checkpointing_steps=2
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
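Every new module above does `sys.path.append("..")` and imports from examples/test_examples_utils.py, which the diffstat lists at 61 lines but whose hunk falls outside this excerpt. Judging from the deleted test_examples.py below, it presumably just carries over the shared helpers; a sketch reconstructed from that deleted code:

    import os
    import shutil
    import subprocess
    import tempfile
    import unittest
    from typing import List

    from accelerate.utils import write_basic_config


    # These utils relate to ensuring the right error message is received when running scripts
    class SubprocessCallException(Exception):
        pass


    def run_command(command: List[str], return_stdout=False):
        try:
            output = subprocess.check_output(command, stderr=subprocess.STDOUT)
            if return_stdout:
                if hasattr(output, "decode"):
                    output = output.decode("utf-8")
                return output
        except subprocess.CalledProcessError as e:
            raise SubprocessCallException(
                f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
            ) from e


    class ExamplesTestsAccelerate(unittest.TestCase):
        @classmethod
        def setUpClass(cls):
            super().setUpClass()
            cls._tmpdir = tempfile.mkdtemp()
            cls.configPath = os.path.join(cls._tmpdir, "default_config.yml")
            # write a default accelerate config so `accelerate launch` needs no prompts
            write_basic_config(save_location=cls.configPath)
            cls._launch_args = ["accelerate", "launch", "--config_file", cls.configPath]

        @classmethod
        def tearDownClass(cls):
            super().tearDownClass()
            shutil.rmtree(cls._tmpdir)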
diff --git a/examples/test_examples.py b/examples/test_examples.py
deleted file mode 100644
index 292c433a33..0000000000
--- a/examples/test_examples.py
+++ /dev/null
@@ -1,1725 +0,0 @@
-# coding=utf-8
-# Copyright 2023 HuggingFace Inc..
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import logging
-import os
-import shutil
-import subprocess
-import sys
-import tempfile
-import unittest
-from typing import List
-
-import safetensors
-from accelerate.utils import write_basic_config
-
-from diffusers import DiffusionPipeline, UNet2DConditionModel
-
-
-logging.basicConfig(level=logging.DEBUG)
-
-logger = logging.getLogger()
-
-
-# These utils relate to ensuring the right error message is received when running scripts
-class SubprocessCallException(Exception):
-    pass
-
-
-def run_command(command: List[str], return_stdout=False):
-    """
-    Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
-    if an error occurred while running `command`
-    """
-    try:
-        output = subprocess.check_output(command, stderr=subprocess.STDOUT)
-        if return_stdout:
-            if hasattr(output, "decode"):
-                output = output.decode("utf-8")
-            return output
-    except subprocess.CalledProcessError as e:
-        raise SubprocessCallException(
-            f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
-        ) from e
-
-
-stream_handler = logging.StreamHandler(sys.stdout)
-logger.addHandler(stream_handler)
-
-
-class ExamplesTestsAccelerate(unittest.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-        cls._tmpdir = tempfile.mkdtemp()
-        cls.configPath = os.path.join(cls._tmpdir, "default_config.yml")
-
-        write_basic_config(save_location=cls.configPath)
-        cls._launch_args = ["accelerate", "launch", "--config_file", cls.configPath]
-
-    @classmethod
-    def tearDownClass(cls):
-        super().tearDownClass()
-        shutil.rmtree(cls._tmpdir)
-
-    def test_train_unconditional(self):
-        with tempfile.TemporaryDirectory() as tmpdir:
-            test_args = f"""
-                examples/unconditional_image_generation/train_unconditional.py
-                --dataset_name hf-internal-testing/dummy_image_class_data
-                --model_config_name_or_path diffusers/ddpm_dummy
-                --resolution 64
-                --output_dir {tmpdir}
-                --train_batch_size 2
-                --num_epochs 1
-                --gradient_accumulation_steps 1
-                --ddpm_num_inference_steps 2
-                --learning_rate 1e-3
-                --lr_warmup_steps 5
-                """.split()
-
-            run_command(self._launch_args + test_args, return_stdout=True)
-            # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
-
-    def test_textual_inversion(self):
-        with tempfile.TemporaryDirectory() as tmpdir:
-            test_args = f"""
-                examples/textual_inversion/textual_inversion.py
-                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
-                --train_data_dir docs/source/en/imgs
-                --learnable_property object
-                --placeholder_token <cat-toy>
-                --initializer_token a
-                --validation_prompt <cat-toy>
-                --validation_steps 1
-                --save_steps 1
-                --num_vectors 2
-                --resolution 64
-                --train_batch_size 1
-                --gradient_accumulation_steps 1
-                --max_train_steps 2
-                --learning_rate 5.0e-04
-                --scale_lr
-                --lr_scheduler constant
-                --lr_warmup_steps 0
-                --output_dir {tmpdir}
-                """.split()
-
-            run_command(self._launch_args + test_args)
-            # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.safetensors")))
-    def test_dreambooth(self):
-        with tempfile.TemporaryDirectory() as tmpdir:
-            test_args = f"""
-                examples/dreambooth/train_dreambooth.py
-                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
-                --instance_data_dir docs/source/en/imgs
-                --instance_prompt photo
-                --resolution 64
-                --train_batch_size 1
-                --gradient_accumulation_steps 1
-                --max_train_steps 2
-                --learning_rate 5.0e-04
-                --scale_lr
-                --lr_scheduler constant
-                --lr_warmup_steps 0
-                --output_dir {tmpdir}
-                """.split()
-
-            run_command(self._launch_args + test_args)
-            # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
-
-    def test_dreambooth_if(self):
-        with tempfile.TemporaryDirectory() as tmpdir:
-            test_args = f"""
-                examples/dreambooth/train_dreambooth.py
-                --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
-                --instance_data_dir docs/source/en/imgs
-                --instance_prompt photo
-                --resolution 64
-                --train_batch_size 1
-                --gradient_accumulation_steps 1
-                --max_train_steps 2
-                --learning_rate 5.0e-04
-                --scale_lr
-                --lr_scheduler constant
-                --lr_warmup_steps 0
-                --output_dir {tmpdir}
-                --pre_compute_text_embeddings
-                --tokenizer_max_length=77
-                --text_encoder_use_attention_mask
-                """.split()
-
-            run_command(self._launch_args + test_args)
-            # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
f""" - examples/dreambooth/train_dreambooth_lora.py - --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe - --instance_data_dir docs/source/en/imgs - --instance_prompt photo - --resolution 64 - --train_batch_size 1 - --gradient_accumulation_steps 1 - --max_train_steps 2 - --learning_rate 5.0e-04 - --scale_lr - --lr_scheduler constant - --lr_warmup_steps 0 - --output_dir {tmpdir} - """.split() - - run_command(self._launch_args + test_args) - # save_pretrained smoke test - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) - - # make sure the state_dict has the correct naming in the parameters. - lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) - is_lora = all("lora" in k for k in lora_state_dict.keys()) - self.assertTrue(is_lora) - - # when not training the text encoder, all the parameters in the state dict should start - # with `"unet"` in their names. - starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys()) - self.assertTrue(starts_with_unet) - - def test_dreambooth_lora_with_text_encoder(self): - with tempfile.TemporaryDirectory() as tmpdir: - test_args = f""" - examples/dreambooth/train_dreambooth_lora.py - --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe - --instance_data_dir docs/source/en/imgs - --instance_prompt photo - --resolution 64 - --train_batch_size 1 - --gradient_accumulation_steps 1 - --max_train_steps 2 - --learning_rate 5.0e-04 - --scale_lr - --lr_scheduler constant - --lr_warmup_steps 0 - --train_text_encoder - --output_dir {tmpdir} - """.split() - - run_command(self._launch_args + test_args) - # save_pretrained smoke test - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) - - # check `text_encoder` is present at all. - lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) - keys = lora_state_dict.keys() - is_text_encoder_present = any(k.startswith("text_encoder") for k in keys) - self.assertTrue(is_text_encoder_present) - - # the names of the keys of the state dict should either start with `unet` - # or `text_encoder`. - is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys) - self.assertTrue(is_correct_naming) - - def test_dreambooth_lora_if_model(self): - with tempfile.TemporaryDirectory() as tmpdir: - test_args = f""" - examples/dreambooth/train_dreambooth_lora.py - --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe - --instance_data_dir docs/source/en/imgs - --instance_prompt photo - --resolution 64 - --train_batch_size 1 - --gradient_accumulation_steps 1 - --max_train_steps 2 - --learning_rate 5.0e-04 - --scale_lr - --lr_scheduler constant - --lr_warmup_steps 0 - --output_dir {tmpdir} - --pre_compute_text_embeddings - --tokenizer_max_length=77 - --text_encoder_use_attention_mask - """.split() - - run_command(self._launch_args + test_args) - # save_pretrained smoke test - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) - - # make sure the state_dict has the correct naming in the parameters. - lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) - is_lora = all("lora" in k for k in lora_state_dict.keys()) - self.assertTrue(is_lora) - - # when not training the text encoder, all the parameters in the state dict should start - # with `"unet"` in their names. 
-    def test_dreambooth_lora(self):
-        with tempfile.TemporaryDirectory() as tmpdir:
-            test_args = f"""
-                examples/dreambooth/train_dreambooth_lora.py
-                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
-                --instance_data_dir docs/source/en/imgs
-                --instance_prompt photo
-                --resolution 64
-                --train_batch_size 1
-                --gradient_accumulation_steps 1
-                --max_train_steps 2
-                --learning_rate 5.0e-04
-                --scale_lr
-                --lr_scheduler constant
-                --lr_warmup_steps 0
-                --output_dir {tmpdir}
-                """.split()
-
-            run_command(self._launch_args + test_args)
-            # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
-
-            # make sure the state_dict has the correct naming in the parameters.
-            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
-            is_lora = all("lora" in k for k in lora_state_dict.keys())
-            self.assertTrue(is_lora)
-
-            # when not training the text encoder, all the parameters in the state dict should start
-            # with `"unet"` in their names.
-            starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
-            self.assertTrue(starts_with_unet)
-
-    def test_dreambooth_lora_with_text_encoder(self):
-        with tempfile.TemporaryDirectory() as tmpdir:
-            test_args = f"""
-                examples/dreambooth/train_dreambooth_lora.py
-                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
-                --instance_data_dir docs/source/en/imgs
-                --instance_prompt photo
-                --resolution 64
-                --train_batch_size 1
-                --gradient_accumulation_steps 1
-                --max_train_steps 2
-                --learning_rate 5.0e-04
-                --scale_lr
-                --lr_scheduler constant
-                --lr_warmup_steps 0
-                --train_text_encoder
-                --output_dir {tmpdir}
-                """.split()
-
-            run_command(self._launch_args + test_args)
-            # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
-
-            # check `text_encoder` is present at all.
-            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
-            keys = lora_state_dict.keys()
-            is_text_encoder_present = any(k.startswith("text_encoder") for k in keys)
-            self.assertTrue(is_text_encoder_present)
-
-            # the names of the keys of the state dict should either start with `unet`
-            # or `text_encoder`.
-            is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
-            self.assertTrue(is_correct_naming)
-
-    def test_dreambooth_lora_if_model(self):
-        with tempfile.TemporaryDirectory() as tmpdir:
-            test_args = f"""
-                examples/dreambooth/train_dreambooth_lora.py
-                --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
-                --instance_data_dir docs/source/en/imgs
-                --instance_prompt photo
-                --resolution 64
-                --train_batch_size 1
-                --gradient_accumulation_steps 1
-                --max_train_steps 2
-                --learning_rate 5.0e-04
-                --scale_lr
-                --lr_scheduler constant
-                --lr_warmup_steps 0
-                --output_dir {tmpdir}
-                --pre_compute_text_embeddings
-                --tokenizer_max_length=77
-                --text_encoder_use_attention_mask
-                """.split()
-
-            run_command(self._launch_args + test_args)
-            # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
-
-            # make sure the state_dict has the correct naming in the parameters.
-            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
-            is_lora = all("lora" in k for k in lora_state_dict.keys())
-            self.assertTrue(is_lora)
-
-            # when not training the text encoder, all the parameters in the state dict should start
-            # with `"unet"` in their names.
-            starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
-            self.assertTrue(starts_with_unet)
-    def test_dreambooth_lora_sdxl(self):
-        with tempfile.TemporaryDirectory() as tmpdir:
-            test_args = f"""
-                examples/dreambooth/train_dreambooth_lora_sdxl.py
-                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
-                --instance_data_dir docs/source/en/imgs
-                --instance_prompt photo
-                --resolution 64
-                --train_batch_size 1
-                --gradient_accumulation_steps 1
-                --max_train_steps 2
-                --learning_rate 5.0e-04
-                --scale_lr
-                --lr_scheduler constant
-                --lr_warmup_steps 0
-                --output_dir {tmpdir}
-                """.split()
-
-            run_command(self._launch_args + test_args)
-            # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
-
-            # make sure the state_dict has the correct naming in the parameters.
-            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
-            is_lora = all("lora" in k for k in lora_state_dict.keys())
-            self.assertTrue(is_lora)
-
-            # when not training the text encoder, all the parameters in the state dict should start
-            # with `"unet"` in their names.
-            starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
-            self.assertTrue(starts_with_unet)
-
-    def test_dreambooth_lora_sdxl_with_text_encoder(self):
-        with tempfile.TemporaryDirectory() as tmpdir:
-            test_args = f"""
-                examples/dreambooth/train_dreambooth_lora_sdxl.py
-                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
-                --instance_data_dir docs/source/en/imgs
-                --instance_prompt photo
-                --resolution 64
-                --train_batch_size 1
-                --gradient_accumulation_steps 1
-                --max_train_steps 2
-                --learning_rate 5.0e-04
-                --scale_lr
-                --lr_scheduler constant
-                --lr_warmup_steps 0
-                --output_dir {tmpdir}
-                --train_text_encoder
-                """.split()
-
-            run_command(self._launch_args + test_args)
-            # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
-
-            # make sure the state_dict has the correct naming in the parameters.
-            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
-            is_lora = all("lora" in k for k in lora_state_dict.keys())
-            self.assertTrue(is_lora)
-
-            # when not training the text encoder, all the parameters in the state dict should start
-            # with `"unet"` or `"text_encoder"` or `"text_encoder_2"` in their names.
-            keys = lora_state_dict.keys()
-            starts_with_unet = all(
-                k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys
-            )
-            self.assertTrue(starts_with_unet)
-    def test_dreambooth_lora_sdxl_custom_captions(self):
-        with tempfile.TemporaryDirectory() as tmpdir:
-            test_args = f"""
-                examples/dreambooth/train_dreambooth_lora_sdxl.py
-                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
-                --dataset_name hf-internal-testing/dummy_image_text_data
-                --caption_column text
-                --instance_prompt photo
-                --resolution 64
-                --train_batch_size 1
-                --gradient_accumulation_steps 1
-                --max_train_steps 2
-                --learning_rate 5.0e-04
-                --scale_lr
-                --lr_scheduler constant
-                --lr_warmup_steps 0
-                --output_dir {tmpdir}
-                """.split()
-
-            run_command(self._launch_args + test_args)
-
-    def test_dreambooth_lora_sdxl_text_encoder_custom_captions(self):
-        with tempfile.TemporaryDirectory() as tmpdir:
-            test_args = f"""
-                examples/dreambooth/train_dreambooth_lora_sdxl.py
-                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
-                --dataset_name hf-internal-testing/dummy_image_text_data
-                --caption_column text
-                --instance_prompt photo
-                --resolution 64
-                --train_batch_size 1
-                --gradient_accumulation_steps 1
-                --max_train_steps 2
-                --learning_rate 5.0e-04
-                --scale_lr
-                --lr_scheduler constant
-                --lr_warmup_steps 0
-                --output_dir {tmpdir}
-                --train_text_encoder
-                """.split()
-
-            run_command(self._launch_args + test_args)
-
-    def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):
-        pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
-
-        with tempfile.TemporaryDirectory() as tmpdir:
-            test_args = f"""
-                examples/dreambooth/train_dreambooth_lora_sdxl.py
-                --pretrained_model_name_or_path {pipeline_path}
-                --instance_data_dir docs/source/en/imgs
-                --instance_prompt photo
-                --resolution 64
-                --train_batch_size 1
-                --gradient_accumulation_steps 1
-                --max_train_steps 7
-                --checkpointing_steps=2
-                --checkpoints_total_limit=2
-                --learning_rate 5.0e-04
-                --scale_lr
-                --lr_scheduler constant
-                --lr_warmup_steps 0
-                --output_dir {tmpdir}
-                """.split()
-
-            run_command(self._launch_args + test_args)
-
-            pipe = DiffusionPipeline.from_pretrained(pipeline_path)
-            pipe.load_lora_weights(tmpdir)
-            pipe("a prompt", num_inference_steps=2)
-
-            # check checkpoint directories exist
-            self.assertEqual(
-                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
-                # checkpoint-2 should have been deleted
-                {"checkpoint-4", "checkpoint-6"},
-            )
-
-    def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
-        pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
-
-        with tempfile.TemporaryDirectory() as tmpdir:
-            test_args = f"""
-                examples/dreambooth/train_dreambooth_lora_sdxl.py
-                --pretrained_model_name_or_path {pipeline_path}
-                --instance_data_dir docs/source/en/imgs
-                --instance_prompt photo
-                --resolution 64
-                --train_batch_size 1
-                --gradient_accumulation_steps 1
-                --max_train_steps 7
-                --checkpointing_steps=2
-                --checkpoints_total_limit=2
-                --train_text_encoder
-                --learning_rate 5.0e-04
-                --scale_lr
-                --lr_scheduler constant
-                --lr_warmup_steps 0
-                --output_dir {tmpdir}
-                """.split()
-
-            run_command(self._launch_args + test_args)
-
-            pipe = DiffusionPipeline.from_pretrained(pipeline_path)
-            pipe.load_lora_weights(tmpdir)
-            pipe("a prompt", num_inference_steps=2)
-
-            # check checkpoint directories exist
-            self.assertEqual(
-                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
-                # checkpoint-2 should have been deleted
-                {"checkpoint-4", "checkpoint-6"},
-            )
hf-internal-testing/dummy_image_text_data - --resolution 64 - --center_crop - --random_flip - --train_batch_size 1 - --gradient_accumulation_steps 1 - --max_train_steps 7 - --learning_rate 5.0e-04 - --scale_lr - --lr_scheduler constant - --lr_warmup_steps 0 - --output_dir {tmpdir} - --checkpointing_steps=2 - --resume_from_checkpoint=checkpoint-4 - --seed=0 - """.split() - - run_command(self._launch_args + resume_run_args) - - # check can run new fully trained pipeline - pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) - pipe(prompt, num_inference_steps=2) - - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - { - # no checkpoint-2 -> check old checkpoints do not exist - # check new checkpoints exist - "checkpoint-4", - "checkpoint-6", - }, - ) - - def test_text_to_image_checkpointing_use_ema(self): - pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" - prompt = "a prompt" - - with tempfile.TemporaryDirectory() as tmpdir: - # Run training script with checkpointing - # max_train_steps == 5, checkpointing_steps == 2 - # Should create checkpoints at steps 2, 4 - - initial_run_args = f""" - examples/text_to_image/train_text_to_image.py - --pretrained_model_name_or_path {pretrained_model_name_or_path} - --dataset_name hf-internal-testing/dummy_image_text_data - --resolution 64 - --center_crop - --random_flip - --train_batch_size 1 - --gradient_accumulation_steps 1 - --max_train_steps 5 - --learning_rate 5.0e-04 - --scale_lr - --lr_scheduler constant - --lr_warmup_steps 0 - --output_dir {tmpdir} - --checkpointing_steps=2 - --use_ema - --seed=0 - """.split() - - run_command(self._launch_args + initial_run_args) - - pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) - pipe(prompt, num_inference_steps=2) - - # check checkpoint directories exist - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-2", "checkpoint-4"}, - ) - - # check can run an intermediate checkpoint - unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet") - pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None) - pipe(prompt, num_inference_steps=2) - - # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming - shutil.rmtree(os.path.join(tmpdir, "checkpoint-2")) - - # Run training script for 7 total steps resuming from checkpoint 4 - - resume_run_args = f""" - examples/text_to_image/train_text_to_image.py - --pretrained_model_name_or_path {pretrained_model_name_or_path} - --dataset_name hf-internal-testing/dummy_image_text_data - --resolution 64 - --center_crop - --random_flip - --train_batch_size 1 - --gradient_accumulation_steps 1 - --max_train_steps 7 - --learning_rate 5.0e-04 - --scale_lr - --lr_scheduler constant - --lr_warmup_steps 0 - --output_dir {tmpdir} - --checkpointing_steps=2 - --resume_from_checkpoint=checkpoint-4 - --use_ema - --seed=0 - """.split() - - run_command(self._launch_args + resume_run_args) - - # check can run new fully trained pipeline - pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) - pipe(prompt, num_inference_steps=2) - - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - { - # no checkpoint-2 -> check old checkpoints do not exist - # check new checkpoints exist - "checkpoint-4", - "checkpoint-6", - }, - ) - - def test_text_to_image_checkpointing_checkpoints_total_limit(self): - pretrained_model_name_or_path = 
"hf-internal-testing/tiny-stable-diffusion-pipe" - prompt = "a prompt" - - with tempfile.TemporaryDirectory() as tmpdir: - # Run training script with checkpointing - # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2 - # Should create checkpoints at steps 2, 4, 6 - # with checkpoint at step 2 deleted - - initial_run_args = f""" - examples/text_to_image/train_text_to_image.py - --pretrained_model_name_or_path {pretrained_model_name_or_path} - --dataset_name hf-internal-testing/dummy_image_text_data - --resolution 64 - --center_crop - --random_flip - --train_batch_size 1 - --gradient_accumulation_steps 1 - --max_train_steps 7 - --learning_rate 5.0e-04 - --scale_lr - --lr_scheduler constant - --lr_warmup_steps 0 - --output_dir {tmpdir} - --checkpointing_steps=2 - --checkpoints_total_limit=2 - --seed=0 - """.split() - - run_command(self._launch_args + initial_run_args) - - pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) - pipe(prompt, num_inference_steps=2) - - # check checkpoint directories exist - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - # checkpoint-2 should have been deleted - {"checkpoint-4", "checkpoint-6"}, - ) - - def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): - pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" - prompt = "a prompt" - - with tempfile.TemporaryDirectory() as tmpdir: - # Run training script with checkpointing - # max_train_steps == 9, checkpointing_steps == 2 - # Should create checkpoints at steps 2, 4, 6, 8 - - initial_run_args = f""" - examples/text_to_image/train_text_to_image.py - --pretrained_model_name_or_path {pretrained_model_name_or_path} - --dataset_name hf-internal-testing/dummy_image_text_data - --resolution 64 - --center_crop - --random_flip - --train_batch_size 1 - --gradient_accumulation_steps 1 - --max_train_steps 9 - --learning_rate 5.0e-04 - --scale_lr - --lr_scheduler constant - --lr_warmup_steps 0 - --output_dir {tmpdir} - --checkpointing_steps=2 - --seed=0 - """.split() - - run_command(self._launch_args + initial_run_args) - - pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) - pipe(prompt, num_inference_steps=2) - - # check checkpoint directories exist - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, - ) - - # resume and we should try to checkpoint at 10, where we'll have to remove - # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint - - resume_run_args = f""" - examples/text_to_image/train_text_to_image.py - --pretrained_model_name_or_path {pretrained_model_name_or_path} - --dataset_name hf-internal-testing/dummy_image_text_data - --resolution 64 - --center_crop - --random_flip - --train_batch_size 1 - --gradient_accumulation_steps 1 - --max_train_steps 11 - --learning_rate 5.0e-04 - --scale_lr - --lr_scheduler constant - --lr_warmup_steps 0 - --output_dir {tmpdir} - --checkpointing_steps=2 - --resume_from_checkpoint=checkpoint-8 - --checkpoints_total_limit=3 - --seed=0 - """.split() - - run_command(self._launch_args + resume_run_args) - - pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) - pipe(prompt, num_inference_steps=2) - - # check checkpoint directories exist - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, - ) - - def 
test_text_to_image_sdxl(self): - with tempfile.TemporaryDirectory() as tmpdir: - test_args = f""" - examples/text_to_image/train_text_to_image_sdxl.py - --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe - --dataset_name hf-internal-testing/dummy_image_text_data - --resolution 64 - --center_crop - --random_flip - --train_batch_size 1 - --gradient_accumulation_steps 1 - --max_train_steps 2 - --learning_rate 5.0e-04 - --scale_lr - --lr_scheduler constant - --lr_warmup_steps 0 - --output_dir {tmpdir} - """.split() - - run_command(self._launch_args + test_args) - # save_pretrained smoke test - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors"))) - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) - - def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self): - pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" - prompt = "a prompt" - - with tempfile.TemporaryDirectory() as tmpdir: - # Run training script with checkpointing - # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2 - # Should create checkpoints at steps 2, 4, 6 - # with checkpoint at step 2 deleted - - initial_run_args = f""" - examples/text_to_image/train_text_to_image_lora.py - --pretrained_model_name_or_path {pretrained_model_name_or_path} - --dataset_name hf-internal-testing/dummy_image_text_data - --resolution 64 - --center_crop - --random_flip - --train_batch_size 1 - --gradient_accumulation_steps 1 - --max_train_steps 7 - --learning_rate 5.0e-04 - --scale_lr - --lr_scheduler constant - --lr_warmup_steps 0 - --output_dir {tmpdir} - --checkpointing_steps=2 - --checkpoints_total_limit=2 - --seed=0 - --num_validation_images=0 - """.split() - - run_command(self._launch_args + initial_run_args) - - pipe = DiffusionPipeline.from_pretrained( - "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None - ) - pipe.load_lora_weights(tmpdir) - pipe(prompt, num_inference_steps=2) - - # check checkpoint directories exist - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - # checkpoint-2 should have been deleted - {"checkpoint-4", "checkpoint-6"}, - ) - - def test_text_to_image_lora_sdxl_checkpointing_checkpoints_total_limit(self): - prompt = "a prompt" - pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe" - - with tempfile.TemporaryDirectory() as tmpdir: - # Run training script with checkpointing - # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2 - # Should create checkpoints at steps 2, 4, 6 - # with checkpoint at step 2 deleted - - initial_run_args = f""" - examples/text_to_image/train_text_to_image_lora_sdxl.py - --pretrained_model_name_or_path {pipeline_path} - --dataset_name hf-internal-testing/dummy_image_text_data - --resolution 64 - --train_batch_size 1 - --gradient_accumulation_steps 1 - --max_train_steps 7 - --learning_rate 5.0e-04 - --scale_lr - --lr_scheduler constant - --lr_warmup_steps 0 - --output_dir {tmpdir} - --checkpointing_steps=2 - --checkpoints_total_limit=2 - """.split() - - run_command(self._launch_args + initial_run_args) - - pipe = DiffusionPipeline.from_pretrained(pipeline_path) - pipe.load_lora_weights(tmpdir) - pipe(prompt, num_inference_steps=2) - - # check checkpoint directories exist - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - # checkpoint-2 should have been deleted - {"checkpoint-4", "checkpoint-6"}, - ) 
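The `--checkpoints_total_limit` behaviour asserted throughout these tests lives in the training scripts themselves, not in the test harness. A minimal sketch of that rotation, assuming oldest-first pruning right before each new save, with `output_dir` and `limit` as hypothetical stand-ins for the script's parsed arguments:

    import os
    import shutil

    # checkpoint dirs sorted oldest-first by their global-step suffix
    checkpoints = sorted(
        (d for d in os.listdir(output_dir) if d.startswith("checkpoint")),
        key=lambda d: int(d.split("-")[1]),
    )
    # before writing the next checkpoint, delete enough old ones to respect the limit
    if len(checkpoints) >= limit:
        for stale in checkpoints[: len(checkpoints) - limit + 1]:
            shutil.rmtree(os.path.join(output_dir, stale))

Under this scheme a run to step 7 with `checkpointing_steps=2` and a limit of 2 ends with exactly {"checkpoint-4", "checkpoint-6"}, which is what the assertions in these tests encode.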
- - def test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self): - prompt = "a prompt" - pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe" - - with tempfile.TemporaryDirectory() as tmpdir: - # Run training script with checkpointing - # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2 - # Should create checkpoints at steps 2, 4, 6 - # with checkpoint at step 2 deleted - - initial_run_args = f""" - examples/text_to_image/train_text_to_image_lora_sdxl.py - --pretrained_model_name_or_path {pipeline_path} - --dataset_name hf-internal-testing/dummy_image_text_data - --resolution 64 - --train_batch_size 1 - --gradient_accumulation_steps 1 - --max_train_steps 7 - --learning_rate 5.0e-04 - --scale_lr - --lr_scheduler constant - --train_text_encoder - --lr_warmup_steps 0 - --output_dir {tmpdir} - --checkpointing_steps=2 - --checkpoints_total_limit=2 - """.split() - - run_command(self._launch_args + initial_run_args) - - pipe = DiffusionPipeline.from_pretrained(pipeline_path) - pipe.load_lora_weights(tmpdir) - pipe(prompt, num_inference_steps=2) - - # check checkpoint directories exist - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - # checkpoint-2 should have been deleted - {"checkpoint-4", "checkpoint-6"}, - ) - - def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): - pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" - prompt = "a prompt" - - with tempfile.TemporaryDirectory() as tmpdir: - # Run training script with checkpointing - # max_train_steps == 9, checkpointing_steps == 2 - # Should create checkpoints at steps 2, 4, 6, 8 - - initial_run_args = f""" - examples/text_to_image/train_text_to_image_lora.py - --pretrained_model_name_or_path {pretrained_model_name_or_path} - --dataset_name hf-internal-testing/dummy_image_text_data - --resolution 64 - --center_crop - --random_flip - --train_batch_size 1 - --gradient_accumulation_steps 1 - --max_train_steps 9 - --learning_rate 5.0e-04 - --scale_lr - --lr_scheduler constant - --lr_warmup_steps 0 - --output_dir {tmpdir} - --checkpointing_steps=2 - --seed=0 - --num_validation_images=0 - """.split() - - run_command(self._launch_args + initial_run_args) - - pipe = DiffusionPipeline.from_pretrained( - "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None - ) - pipe.load_lora_weights(tmpdir) - pipe(prompt, num_inference_steps=2) - - # check checkpoint directories exist - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, - ) - - # resume and we should try to checkpoint at 10, where we'll have to remove - # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint - - resume_run_args = f""" - examples/text_to_image/train_text_to_image_lora.py - --pretrained_model_name_or_path {pretrained_model_name_or_path} - --dataset_name hf-internal-testing/dummy_image_text_data - --resolution 64 - --center_crop - --random_flip - --train_batch_size 1 - --gradient_accumulation_steps 1 - --max_train_steps 11 - --learning_rate 5.0e-04 - --scale_lr - --lr_scheduler constant - --lr_warmup_steps 0 - --output_dir {tmpdir} - --checkpointing_steps=2 - --resume_from_checkpoint=checkpoint-8 - --checkpoints_total_limit=3 - --seed=0 - --num_validation_images=0 - """.split() - - run_command(self._launch_args + resume_run_args) - - pipe = DiffusionPipeline.from_pretrained( - 
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None - ) - pipe.load_lora_weights(tmpdir) - pipe(prompt, num_inference_steps=2) - - # check checkpoint directories exist - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, - ) - - def test_unconditional_checkpointing_checkpoints_total_limit(self): - with tempfile.TemporaryDirectory() as tmpdir: - initial_run_args = f""" - examples/unconditional_image_generation/train_unconditional.py - --dataset_name hf-internal-testing/dummy_image_class_data - --model_config_name_or_path diffusers/ddpm_dummy - --resolution 64 - --output_dir {tmpdir} - --train_batch_size 1 - --num_epochs 1 - --gradient_accumulation_steps 1 - --ddpm_num_inference_steps 2 - --learning_rate 1e-3 - --lr_warmup_steps 5 - --checkpointing_steps=2 - --checkpoints_total_limit=2 - """.split() - - run_command(self._launch_args + initial_run_args) - - # check checkpoint directories exist - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - # checkpoint-2 should have been deleted - {"checkpoint-4", "checkpoint-6"}, - ) - - def test_unconditional_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): - with tempfile.TemporaryDirectory() as tmpdir: - initial_run_args = f""" - examples/unconditional_image_generation/train_unconditional.py - --dataset_name hf-internal-testing/dummy_image_class_data - --model_config_name_or_path diffusers/ddpm_dummy - --resolution 64 - --output_dir {tmpdir} - --train_batch_size 1 - --num_epochs 1 - --gradient_accumulation_steps 1 - --ddpm_num_inference_steps 2 - --learning_rate 1e-3 - --lr_warmup_steps 5 - --checkpointing_steps=1 - """.split() - - run_command(self._launch_args + initial_run_args) - - # check checkpoint directories exist - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-1", "checkpoint-2", "checkpoint-3", "checkpoint-4", "checkpoint-5", "checkpoint-6"}, - ) - - resume_run_args = f""" - examples/unconditional_image_generation/train_unconditional.py - --dataset_name hf-internal-testing/dummy_image_class_data - --model_config_name_or_path diffusers/ddpm_dummy - --resolution 64 - --output_dir {tmpdir} - --train_batch_size 1 - --num_epochs 2 - --gradient_accumulation_steps 1 - --ddpm_num_inference_steps 2 - --learning_rate 1e-3 - --lr_warmup_steps 5 - --resume_from_checkpoint=checkpoint-6 - --checkpointing_steps=2 - --checkpoints_total_limit=3 - """.split() - - run_command(self._launch_args + resume_run_args) - - # check checkpoint directories exist - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-8", "checkpoint-10", "checkpoint-12"}, - ) - - def test_textual_inversion_checkpointing(self): - with tempfile.TemporaryDirectory() as tmpdir: - test_args = f""" - examples/textual_inversion/textual_inversion.py - --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe - --train_data_dir docs/source/en/imgs - --learnable_property object - --placeholder_token - --initializer_token a - --validation_prompt - --validation_steps 1 - --save_steps 1 - --num_vectors 2 - --resolution 64 - --train_batch_size 1 - --gradient_accumulation_steps 1 - --max_train_steps 3 - --learning_rate 5.0e-04 - --scale_lr - --lr_scheduler constant - --lr_warmup_steps 0 - --output_dir {tmpdir} - --checkpointing_steps=1 - --checkpoints_total_limit=2 - """.split() - - run_command(self._launch_args + test_args) - - # check checkpoint directories 
exist - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-2", "checkpoint-3"}, - ) - - def test_textual_inversion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): - with tempfile.TemporaryDirectory() as tmpdir: - test_args = f""" - examples/textual_inversion/textual_inversion.py - --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe - --train_data_dir docs/source/en/imgs - --learnable_property object - --placeholder_token <cat-toy> - --initializer_token a - --validation_prompt <cat-toy> - --validation_steps 1 - --save_steps 1 - --num_vectors 2 - --resolution 64 - --train_batch_size 1 - --gradient_accumulation_steps 1 - --max_train_steps 3 - --learning_rate 5.0e-04 - --scale_lr - --lr_scheduler constant - --lr_warmup_steps 0 - --output_dir {tmpdir} - --checkpointing_steps=1 - """.split() - - run_command(self._launch_args + test_args) - - # check checkpoint directories exist - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-1", "checkpoint-2", "checkpoint-3"}, - ) - - resume_run_args = f""" - examples/textual_inversion/textual_inversion.py - --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe - --train_data_dir docs/source/en/imgs - --learnable_property object - --placeholder_token <cat-toy> - --initializer_token a - --validation_prompt <cat-toy> - --validation_steps 1 - --save_steps 1 - --num_vectors 2 - --resolution 64 - --train_batch_size 1 - --gradient_accumulation_steps 1 - --max_train_steps 4 - --learning_rate 5.0e-04 - --scale_lr - --lr_scheduler constant - --lr_warmup_steps 0 - --output_dir {tmpdir} - --checkpointing_steps=1 - --resume_from_checkpoint=checkpoint-3 - --checkpoints_total_limit=2 - """.split() - - run_command(self._launch_args + resume_run_args) - - # check checkpoint directories exist - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-3", "checkpoint-4"}, - ) - - def test_instruct_pix2pix_checkpointing_checkpoints_total_limit(self): - with tempfile.TemporaryDirectory() as tmpdir: - test_args = f""" - examples/instruct_pix2pix/train_instruct_pix2pix.py - --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe - --dataset_name=hf-internal-testing/instructpix2pix-10-samples - --resolution=64 - --random_flip - --train_batch_size=1 - --max_train_steps=7 - --checkpointing_steps=2 - --checkpoints_total_limit=2 - --output_dir {tmpdir} - --seed=0 - """.split() - - run_command(self._launch_args + test_args) - - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-4", "checkpoint-6"}, - ) - - def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): - with tempfile.TemporaryDirectory() as tmpdir: - test_args = f""" - examples/instruct_pix2pix/train_instruct_pix2pix.py - --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe - --dataset_name=hf-internal-testing/instructpix2pix-10-samples - --resolution=64 - --random_flip - --train_batch_size=1 - --max_train_steps=9 - --checkpointing_steps=2 - --output_dir {tmpdir} - --seed=0 - """.split() - - run_command(self._launch_args + test_args) - - # check checkpoint directories exist - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, - ) - - resume_run_args = f""" - examples/instruct_pix2pix/train_instruct_pix2pix.py - 
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe - --dataset_name=hf-internal-testing/instructpix2pix-10-samples - --resolution=64 - --random_flip - --train_batch_size=1 - --max_train_steps=11 - --checkpointing_steps=2 - --output_dir {tmpdir} - --seed=0 - --resume_from_checkpoint=checkpoint-8 - --checkpoints_total_limit=3 - """.split() - - run_command(self._launch_args + resume_run_args) - - # check checkpoint directories exist - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, - ) - - def test_dreambooth_checkpointing_checkpoints_total_limit(self): - with tempfile.TemporaryDirectory() as tmpdir: - test_args = f""" - examples/dreambooth/train_dreambooth.py - --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe - --instance_data_dir=docs/source/en/imgs - --output_dir={tmpdir} - --instance_prompt=prompt - --resolution=64 - --train_batch_size=1 - --gradient_accumulation_steps=1 - --max_train_steps=6 - --checkpoints_total_limit=2 - --checkpointing_steps=2 - """.split() - - run_command(self._launch_args + test_args) - - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-4", "checkpoint-6"}, - ) - - def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): - with tempfile.TemporaryDirectory() as tmpdir: - test_args = f""" - examples/dreambooth/train_dreambooth.py - --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe - --instance_data_dir=docs/source/en/imgs - --output_dir={tmpdir} - --instance_prompt=prompt - --resolution=64 - --train_batch_size=1 - --gradient_accumulation_steps=1 - --max_train_steps=9 - --checkpointing_steps=2 - """.split() - - run_command(self._launch_args + test_args) - - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, - ) - - resume_run_args = f""" - examples/dreambooth/train_dreambooth.py - --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe - --instance_data_dir=docs/source/en/imgs - --output_dir={tmpdir} - --instance_prompt=prompt - --resolution=64 - --train_batch_size=1 - --gradient_accumulation_steps=1 - --max_train_steps=11 - --checkpointing_steps=2 - --resume_from_checkpoint=checkpoint-8 - --checkpoints_total_limit=3 - """.split() - - run_command(self._launch_args + resume_run_args) - - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, - ) - - def test_dreambooth_lora_checkpointing_checkpoints_total_limit(self): - with tempfile.TemporaryDirectory() as tmpdir: - test_args = f""" - examples/dreambooth/train_dreambooth_lora.py - --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe - --instance_data_dir=docs/source/en/imgs - --output_dir={tmpdir} - --instance_prompt=prompt - --resolution=64 - --train_batch_size=1 - --gradient_accumulation_steps=1 - --max_train_steps=6 - --checkpoints_total_limit=2 - --checkpointing_steps=2 - """.split() - - run_command(self._launch_args + test_args) - - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-4", "checkpoint-6"}, - ) - - def test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): - with tempfile.TemporaryDirectory() as tmpdir: - test_args = f""" - examples/dreambooth/train_dreambooth_lora.py - 
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe - --instance_data_dir=docs/source/en/imgs - --output_dir={tmpdir} - --instance_prompt=prompt - --resolution=64 - --train_batch_size=1 - --gradient_accumulation_steps=1 - --max_train_steps=9 - --checkpointing_steps=2 - """.split() - - run_command(self._launch_args + test_args) - - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, - ) - - resume_run_args = f""" - examples/dreambooth/train_dreambooth_lora.py - --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe - --instance_data_dir=docs/source/en/imgs - --output_dir={tmpdir} - --instance_prompt=prompt - --resolution=64 - --train_batch_size=1 - --gradient_accumulation_steps=1 - --max_train_steps=11 - --checkpointing_steps=2 - --resume_from_checkpoint=checkpoint-8 - --checkpoints_total_limit=3 - """.split() - - run_command(self._launch_args + resume_run_args) - - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, - ) - - def test_controlnet_checkpointing_checkpoints_total_limit(self): - with tempfile.TemporaryDirectory() as tmpdir: - test_args = f""" - examples/controlnet/train_controlnet.py - --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe - --dataset_name=hf-internal-testing/fill10 - --output_dir={tmpdir} - --resolution=64 - --train_batch_size=1 - --gradient_accumulation_steps=1 - --max_train_steps=6 - --checkpoints_total_limit=2 - --checkpointing_steps=2 - --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet - """.split() - - run_command(self._launch_args + test_args) - - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-4", "checkpoint-6"}, - ) - - def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): - with tempfile.TemporaryDirectory() as tmpdir: - test_args = f""" - examples/controlnet/train_controlnet.py - --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe - --dataset_name=hf-internal-testing/fill10 - --output_dir={tmpdir} - --resolution=64 - --train_batch_size=1 - --gradient_accumulation_steps=1 - --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet - --max_train_steps=9 - --checkpointing_steps=2 - """.split() - - run_command(self._launch_args + test_args) - - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, - ) - - resume_run_args = f""" - examples/controlnet/train_controlnet.py - --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe - --dataset_name=hf-internal-testing/fill10 - --output_dir={tmpdir} - --resolution=64 - --train_batch_size=1 - --gradient_accumulation_steps=1 - --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet - --max_train_steps=11 - --checkpointing_steps=2 - --resume_from_checkpoint=checkpoint-8 - --checkpoints_total_limit=3 - """.split() - - run_command(self._launch_args + resume_run_args) - - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-8", "checkpoint-10", "checkpoint-12"}, - ) - - def test_controlnet_sdxl(self): - with tempfile.TemporaryDirectory() as tmpdir: - test_args = f""" - examples/controlnet/train_controlnet_sdxl.py - 
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe - --dataset_name=hf-internal-testing/fill10 - --output_dir={tmpdir} - --resolution=64 - --train_batch_size=1 - --gradient_accumulation_steps=1 - --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl - --max_train_steps=9 - --checkpointing_steps=2 - """.split() - - run_command(self._launch_args + test_args) - - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors"))) - - def test_t2i_adapter_sdxl(self): - with tempfile.TemporaryDirectory() as tmpdir: - test_args = f""" - examples/t2i_adapter/train_t2i_adapter_sdxl.py - --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe - --adapter_model_name_or_path=hf-internal-testing/tiny-adapter - --dataset_name=hf-internal-testing/fill10 - --output_dir={tmpdir} - --resolution=64 - --train_batch_size=1 - --gradient_accumulation_steps=1 - --max_train_steps=9 - --checkpointing_steps=2 - """.split() - - run_command(self._launch_args + test_args) - - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors"))) - - def test_custom_diffusion_checkpointing_checkpoints_total_limit(self): - with tempfile.TemporaryDirectory() as tmpdir: - test_args = f""" - examples/custom_diffusion/train_custom_diffusion.py - --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe - --instance_data_dir=docs/source/en/imgs - --output_dir={tmpdir} - --instance_prompt=<new1> - --resolution=64 - --train_batch_size=1 - --modifier_token=<new1> - --dataloader_num_workers=0 - --max_train_steps=6 - --checkpoints_total_limit=2 - --checkpointing_steps=2 - --no_safe_serialization - """.split() - - run_command(self._launch_args + test_args) - - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-4", "checkpoint-6"}, - ) - - def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): - with tempfile.TemporaryDirectory() as tmpdir: - test_args = f""" - examples/custom_diffusion/train_custom_diffusion.py - --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe - --instance_data_dir=docs/source/en/imgs - --output_dir={tmpdir} - --instance_prompt=<new1> - --resolution=64 - --train_batch_size=1 - --modifier_token=<new1> - --dataloader_num_workers=0 - --max_train_steps=9 - --checkpointing_steps=2 - --no_safe_serialization - """.split() - - run_command(self._launch_args + test_args) - - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, - ) - - resume_run_args = f""" - examples/custom_diffusion/train_custom_diffusion.py - --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe - --instance_data_dir=docs/source/en/imgs - --output_dir={tmpdir} - --instance_prompt=<new1> - --resolution=64 - --train_batch_size=1 - --modifier_token=<new1> - --dataloader_num_workers=0 - --max_train_steps=11 - --checkpointing_steps=2 - --resume_from_checkpoint=checkpoint-8 - --checkpoints_total_limit=3 - --no_safe_serialization - """.split() - - run_command(self._launch_args + resume_run_args) - - self.assertEqual( - {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, - )
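For context on the two artifacts the custom diffusion tests check for: `pytorch_custom_diffusion_weights.bin` holds the trained cross-attention processors, and `<new1>.bin` the learned embedding for the modifier token. A hedged sketch of how a finished run would be consumed at inference, assuming the loaders documented for custom diffusion at the time and a hypothetical `output_dir` pointing at the training output:

    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe")
    # attention processors written by train_custom_diffusion.py
    pipe.unet.load_attn_procs(output_dir, weight_name="pytorch_custom_diffusion_weights.bin")
    # the <new1> modifier-token embedding
    pipe.load_textual_inversion(output_dir, weight_name="<new1>.bin")
    pipe("a <new1> photo", num_inference_steps=2)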
- - def test_text_to_image_lora_sdxl(self): - with tempfile.TemporaryDirectory() as tmpdir: - test_args = f""" - examples/text_to_image/train_text_to_image_lora_sdxl.py - --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe - --dataset_name hf-internal-testing/dummy_image_text_data - --resolution 64 - --train_batch_size 1 - --gradient_accumulation_steps 1 - --max_train_steps 2 - --learning_rate 5.0e-04 - --scale_lr - --lr_scheduler constant - --lr_warmup_steps 0 - --output_dir {tmpdir} - """.split() - - run_command(self._launch_args + test_args) - # save_pretrained smoke test - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) - - # make sure the state_dict has the correct naming in the parameters. - lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) - is_lora = all("lora" in k for k in lora_state_dict.keys()) - self.assertTrue(is_lora) - - def test_text_to_image_lora_sdxl_with_text_encoder(self): - with tempfile.TemporaryDirectory() as tmpdir: - test_args = f""" - examples/text_to_image/train_text_to_image_lora_sdxl.py - --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe - --dataset_name hf-internal-testing/dummy_image_text_data - --resolution 64 - --train_batch_size 1 - --gradient_accumulation_steps 1 - --max_train_steps 2 - --learning_rate 5.0e-04 - --scale_lr - --lr_scheduler constant - --lr_warmup_steps 0 - --output_dir {tmpdir} - --train_text_encoder - """.split() - - run_command(self._launch_args + test_args) - # save_pretrained smoke test - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) - - # make sure the state_dict has the correct naming in the parameters. - lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) - is_lora = all("lora" in k for k in lora_state_dict.keys()) - self.assertTrue(is_lora) - - # when training the text encoder, all the parameters in the state dict should start - # with `"unet"` or `"text_encoder"` or `"text_encoder_2"` in their names. - keys = lora_state_dict.keys() - starts_with_unet = all( - k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys - ) - self.assertTrue(starts_with_unet) diff --git a/examples/test_examples_utils.py b/examples/test_examples_utils.py new file mode 100644 index 0000000000..3a697c65c4 --- /dev/null +++ b/examples/test_examples_utils.py @@ -0,0 +1,61 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import subprocess
+import tempfile
+import unittest
+from typing import List
+
+from accelerate.utils import write_basic_config
+
+
+# These utils relate to ensuring the right error message is received when running scripts
+class SubprocessCallException(Exception):
+    pass
+
+
+def run_command(command: List[str], return_stdout=False):
+    """
+    Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. 
Will also properly capture and re-raise + any error that occurs while running `command` + """ + try: + output = subprocess.check_output(command, stderr=subprocess.STDOUT) + if return_stdout: + if hasattr(output, "decode"): + output = output.decode("utf-8") + return output + except subprocess.CalledProcessError as e: + raise SubprocessCallException( + f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}" + ) from e + + +class ExamplesTestsAccelerate(unittest.TestCase): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._tmpdir = tempfile.mkdtemp() + cls.configPath = os.path.join(cls._tmpdir, "default_config.yml") + + write_basic_config(save_location=cls.configPath) + cls._launch_args = ["accelerate", "launch", "--config_file", cls.configPath] + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + shutil.rmtree(cls._tmpdir) diff --git a/examples/text_to_image/test_text_to_image.py b/examples/text_to_image/test_text_to_image.py new file mode 100644 index 0000000000..308a038b55 --- /dev/null +++ b/examples/text_to_image/test_text_to_image.py @@ -0,0 +1,373 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import shutil +import sys +import tempfile + +from diffusers import DiffusionPipeline, UNet2DConditionModel # noqa: E402 + + +sys.path.append("..") +from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 + + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger() +stream_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(stream_handler) + + +class TextToImage(ExamplesTestsAccelerate): + def test_text_to_image(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/text_to_image/train_text_to_image.py + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) + + def test_text_to_image_checkpointing(self): + pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" + prompt = "a prompt" + + with tempfile.TemporaryDirectory() as tmpdir: + # Run training script with checkpointing + # max_train_steps == 5, checkpointing_steps == 2 + # Should create checkpoints at steps 2, 4 + + initial_run_args = f""" + examples/text_to_image/train_text_to_image.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name 
hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 5 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=2 + --seed=0 + """.split() + + run_command(self._launch_args + initial_run_args) + + pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) + pipe(prompt, num_inference_steps=2) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4"}, + ) + + # check can run an intermediate checkpoint + unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet") + pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None) + pipe(prompt, num_inference_steps=2) + + # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming + shutil.rmtree(os.path.join(tmpdir, "checkpoint-2")) + + # Run training script for 7 total steps resuming from checkpoint 4 + + resume_run_args = f""" + examples/text_to_image/train_text_to_image.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 7 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=2 + --resume_from_checkpoint=checkpoint-4 + --seed=0 + """.split() + + run_command(self._launch_args + resume_run_args) + + # check can run new fully trained pipeline + pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) + pipe(prompt, num_inference_steps=2) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + { + # no checkpoint-2 -> check old checkpoints do not exist + # check new checkpoints exist + "checkpoint-4", + "checkpoint-6", + }, + ) + + def test_text_to_image_checkpointing_use_ema(self): + pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" + prompt = "a prompt" + + with tempfile.TemporaryDirectory() as tmpdir: + # Run training script with checkpointing + # max_train_steps == 5, checkpointing_steps == 2 + # Should create checkpoints at steps 2, 4 + + initial_run_args = f""" + examples/text_to_image/train_text_to_image.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 5 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=2 + --use_ema + --seed=0 + """.split() + + run_command(self._launch_args + initial_run_args) + + pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) + pipe(prompt, num_inference_steps=2) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4"}, + ) + + # check can run an intermediate checkpoint + unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet") + pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None) + pipe(prompt, 
num_inference_steps=2) + + # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming + shutil.rmtree(os.path.join(tmpdir, "checkpoint-2")) + + # Run training script for 7 total steps resuming from checkpoint 4 + + resume_run_args = f""" + examples/text_to_image/train_text_to_image.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 7 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=2 + --resume_from_checkpoint=checkpoint-4 + --use_ema + --seed=0 + """.split() + + run_command(self._launch_args + resume_run_args) + + # check can run new fully trained pipeline + pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) + pipe(prompt, num_inference_steps=2) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + { + # no checkpoint-2 -> check old checkpoints do not exist + # check new checkpoints exist + "checkpoint-4", + "checkpoint-6", + }, + ) + + def test_text_to_image_checkpointing_checkpoints_total_limit(self): + pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" + prompt = "a prompt" + + with tempfile.TemporaryDirectory() as tmpdir: + # Run training script with checkpointing + # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2 + # Should create checkpoints at steps 2, 4, 6 + # with checkpoint at step 2 deleted + + initial_run_args = f""" + examples/text_to_image/train_text_to_image.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 7 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=2 + --checkpoints_total_limit=2 + --seed=0 + """.split() + + run_command(self._launch_args + initial_run_args) + + pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) + pipe(prompt, num_inference_steps=2) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + # checkpoint-2 should have been deleted + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): + pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" + prompt = "a prompt" + + with tempfile.TemporaryDirectory() as tmpdir: + # Run training script with checkpointing + # max_train_steps == 9, checkpointing_steps == 2 + # Should create checkpoints at steps 2, 4, 6, 8 + + initial_run_args = f""" + examples/text_to_image/train_text_to_image.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 9 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=2 + --seed=0 + """.split() + + run_command(self._launch_args + initial_run_args) + + pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) + 
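# make sure the partially trained pipeline still loads and runs end-to-end + 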
pipe(prompt, num_inference_steps=2) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, + ) + + # resume and we should try to checkpoint at 10, where we'll have to remove + # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint + + resume_run_args = f""" + examples/text_to_image/train_text_to_image.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 11 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=2 + --resume_from_checkpoint=checkpoint-8 + --checkpoints_total_limit=3 + --seed=0 + """.split() + + run_command(self._launch_args + resume_run_args) + + pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) + pipe(prompt, num_inference_steps=2) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, + ) + + +class TextToImageSDXL(ExamplesTestsAccelerate): + def test_text_to_image_sdxl(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/text_to_image/train_text_to_image_sdxl.py + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) diff --git a/examples/text_to_image/test_text_to_image_lora.py b/examples/text_to_image/test_text_to_image_lora.py new file mode 100644 index 0000000000..83cbb78b2d --- /dev/null +++ b/examples/text_to_image/test_text_to_image_lora.py @@ -0,0 +1,308 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
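As an aside before the LoRA suite: the SDXL smoke test above only asserts that `save_pretrained` left the expected files on disk. A stricter, optional follow-up, reusing the pattern from the checkpointing tests and assuming the tiny SDXL pipe runs on CPU, would round-trip the output directory (`tmpdir` below stands for the training `--output_dir` from the test above):

    from diffusers import DiffusionPipeline

    # reload the freshly saved pipeline and run a 2-step generation
    pipe = DiffusionPipeline.from_pretrained(tmpdir)
    pipe("a prompt", num_inference_steps=2)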
+ +import logging +import os +import sys +import tempfile + +import safetensors + +from diffusers import DiffusionPipeline # noqa: E402 + + +sys.path.append("..") +from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 + + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger() +stream_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(stream_handler) + + +class TextToImageLoRA(ExamplesTestsAccelerate): + def test_text_to_image_lora_sdxl_checkpointing_checkpoints_total_limit(self): + prompt = "a prompt" + pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe" + + with tempfile.TemporaryDirectory() as tmpdir: + # Run training script with checkpointing + # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2 + # Should create checkpoints at steps 2, 4, 6 + # with checkpoint at step 2 deleted + + initial_run_args = f""" + examples/text_to_image/train_text_to_image_lora_sdxl.py + --pretrained_model_name_or_path {pipeline_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 7 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=2 + --checkpoints_total_limit=2 + """.split() + + run_command(self._launch_args + initial_run_args) + + pipe = DiffusionPipeline.from_pretrained(pipeline_path) + pipe.load_lora_weights(tmpdir) + pipe(prompt, num_inference_steps=2) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + # checkpoint-2 should have been deleted + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self): + pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" + prompt = "a prompt" + + with tempfile.TemporaryDirectory() as tmpdir: + # Run training script with checkpointing + # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2 + # Should create checkpoints at steps 2, 4, 6 + # with checkpoint at step 2 deleted + + initial_run_args = f""" + examples/text_to_image/train_text_to_image_lora.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 7 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=2 + --checkpoints_total_limit=2 + --seed=0 + --num_validation_images=0 + """.split() + + run_command(self._launch_args + initial_run_args) + + pipe = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None + ) + pipe.load_lora_weights(tmpdir) + pipe(prompt, num_inference_steps=2) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + # checkpoint-2 should have been deleted + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): + pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" + prompt = "a prompt" + + with tempfile.TemporaryDirectory() as tmpdir: + # Run training script with checkpointing + # max_train_steps == 9, checkpointing_steps == 2 + # 
Should create checkpoints at steps 2, 4, 6, 8 + + initial_run_args = f""" + examples/text_to_image/train_text_to_image_lora.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 9 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=2 + --seed=0 + --num_validation_images=0 + """.split() + + run_command(self._launch_args + initial_run_args) + + pipe = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None + ) + pipe.load_lora_weights(tmpdir) + pipe(prompt, num_inference_steps=2) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, + ) + + # resume and we should try to checkpoint at 10, where we'll have to remove + # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint + + resume_run_args = f""" + examples/text_to_image/train_text_to_image_lora.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 11 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=2 + --resume_from_checkpoint=checkpoint-8 + --checkpoints_total_limit=3 + --seed=0 + --num_validation_images=0 + """.split() + + run_command(self._launch_args + resume_run_args) + + pipe = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None + ) + pipe.load_lora_weights(tmpdir) + pipe(prompt, num_inference_steps=2) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, + ) + + +class TextToImageLoRASDXL(ExamplesTestsAccelerate): + def test_text_to_image_lora_sdxl(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/text_to_image/train_text_to_image_lora_sdxl.py + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. 
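+            # i.e. every parameter key saved by the LoRA script should contain the substring "lora" 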
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + def test_text_to_image_lora_sdxl_with_text_encoder(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/text_to_image/train_text_to_image_lora_sdxl.py + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --train_text_encoder + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when training the text encoder, all the parameters in the state dict should start + # with `"unet"` or `"text_encoder"` or `"text_encoder_2"` in their names. + keys = lora_state_dict.keys() + starts_with_unet = all( + k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys + ) + self.assertTrue(starts_with_unet) + + def test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self): + prompt = "a prompt" + pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe" + + with tempfile.TemporaryDirectory() as tmpdir: + # Run training script with checkpointing + # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2 + # Should create checkpoints at steps 2, 4, 6 + # with checkpoint at step 2 deleted + + initial_run_args = f""" + examples/text_to_image/train_text_to_image_lora_sdxl.py + --pretrained_model_name_or_path {pipeline_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 7 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --train_text_encoder + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=2 + --checkpoints_total_limit=2 + """.split() + + run_command(self._launch_args + initial_run_args) + + pipe = DiffusionPipeline.from_pretrained(pipeline_path) + pipe.load_lora_weights(tmpdir) + pipe(prompt, num_inference_steps=2) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + # checkpoint-2 should have been deleted + {"checkpoint-4", "checkpoint-6"}, + ) diff --git a/examples/textual_inversion/test_textual_inversion.py b/examples/textual_inversion/test_textual_inversion.py new file mode 100644 index 0000000000..a5d7bcb65d --- /dev/null +++ b/examples/textual_inversion/test_textual_inversion.py @@ -0,0 +1,160 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class TextualInversion(ExamplesTestsAccelerate):
+    def test_textual_inversion(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/textual_inversion/textual_inversion.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+                --train_data_dir docs/source/en/imgs
+                --learnable_property object
+                --placeholder_token <cat-toy>
+                --initializer_token a
+                --validation_prompt <cat-toy>
+                --validation_steps 1
+                --save_steps 1
+                --num_vectors 2
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.safetensors")))
+
+    def test_textual_inversion_checkpointing(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/textual_inversion/textual_inversion.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+                --train_data_dir docs/source/en/imgs
+                --learnable_property object
+                --placeholder_token <cat-toy>
+                --initializer_token a
+                --validation_prompt <cat-toy>
+                --validation_steps 1
+                --save_steps 1
+                --num_vectors 2
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 3
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --checkpointing_steps=1
+                --checkpoints_total_limit=2
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            # check checkpoint directories exist
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-2", "checkpoint-3"},
+            )
+
+    def test_textual_inversion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/textual_inversion/textual_inversion.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+                --train_data_dir docs/source/en/imgs
+                --learnable_property object
+                --placeholder_token <cat-toy>
+                --initializer_token a
+                --validation_prompt <cat-toy>
+                --validation_steps 1
+                --save_steps 1
+                --num_vectors 2
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 3
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --checkpointing_steps=1
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            # check checkpoint directories exist
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-1", "checkpoint-2", "checkpoint-3"},
+            )
+
+            resume_run_args = f"""
+                examples/textual_inversion/textual_inversion.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+                --train_data_dir docs/source/en/imgs
+                --learnable_property object
+                --placeholder_token <cat-toy>
+                --initializer_token a
+                --validation_prompt <cat-toy>
+                --validation_steps 1
+                --save_steps 1
+                --num_vectors 2
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 4
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --checkpointing_steps=1
+                --resume_from_checkpoint=checkpoint-3
+                --checkpoints_total_limit=2
+                """.split()
+
+            run_command(self._launch_args + resume_run_args)
+
+            # check checkpoint directories exist
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-3", "checkpoint-4"},
+            )
diff --git a/examples/unconditional_image_generation/test_unconditional.py b/examples/unconditional_image_generation/test_unconditional.py
new file mode 100644
index 0000000000..b7e19abe9f
--- /dev/null
+++ b/examples/unconditional_image_generation/test_unconditional.py
@@ -0,0 +1,130 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class Unconditional(ExamplesTestsAccelerate):
+    def test_train_unconditional(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/unconditional_image_generation/train_unconditional.py
+                --dataset_name hf-internal-testing/dummy_image_class_data
+                --model_config_name_or_path diffusers/ddpm_dummy
+                --resolution 64
+                --output_dir {tmpdir}
+                --train_batch_size 2
+                --num_epochs 1
+                --gradient_accumulation_steps 1
+                --ddpm_num_inference_steps 2
+                --learning_rate 1e-3
+                --lr_warmup_steps 5
+                """.split()
+
+            run_command(self._launch_args + test_args, return_stdout=True)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+    def test_unconditional_checkpointing_checkpoints_total_limit(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            initial_run_args = f"""
+                examples/unconditional_image_generation/train_unconditional.py
+                --dataset_name hf-internal-testing/dummy_image_class_data
+                --model_config_name_or_path diffusers/ddpm_dummy
+                --resolution 64
+                --output_dir {tmpdir}
+                --train_batch_size 1
+                --num_epochs 1
+                --gradient_accumulation_steps 1
+                --ddpm_num_inference_steps 2
+                --learning_rate 1e-3
+                --lr_warmup_steps 5
+                --checkpointing_steps=2
+                --checkpoints_total_limit=2
+                """.split()
+
+            run_command(self._launch_args + initial_run_args)
+
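+            # one epoch over the 6-image dummy dataset at batch size 1 is 6 steps, so
+            # checkpoints are written at steps 2, 4, and 6; with a total limit of 2,
+            # checkpoint-2 is pruned when checkpoint-6 is saved.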
+            # check checkpoint directories exist
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                # checkpoint-2 should have been deleted
+                {"checkpoint-4", "checkpoint-6"},
+            )
+
+    def test_unconditional_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            initial_run_args = f"""
+                examples/unconditional_image_generation/train_unconditional.py
+                --dataset_name hf-internal-testing/dummy_image_class_data
+                --model_config_name_or_path diffusers/ddpm_dummy
+                --resolution 64
+                --output_dir {tmpdir}
+                --train_batch_size 1
+                --num_epochs 1
+                --gradient_accumulation_steps 1
+                --ddpm_num_inference_steps 2
+                --learning_rate 1e-3
+                --lr_warmup_steps 5
+                --checkpointing_steps=1
+                """.split()
+
+            run_command(self._launch_args + initial_run_args)
+
+            # check checkpoint directories exist
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-1", "checkpoint-2", "checkpoint-3", "checkpoint-4", "checkpoint-5", "checkpoint-6"},
+            )
+
+            resume_run_args = f"""
+                examples/unconditional_image_generation/train_unconditional.py
+                --dataset_name hf-internal-testing/dummy_image_class_data
+                --model_config_name_or_path diffusers/ddpm_dummy
+                --resolution 64
+                --output_dir {tmpdir}
+                --train_batch_size 1
+                --num_epochs 2
+                --gradient_accumulation_steps 1
+                --ddpm_num_inference_steps 2
+                --learning_rate 1e-3
+                --lr_warmup_steps 5
+                --resume_from_checkpoint=checkpoint-6
+                --checkpointing_steps=2
+                --checkpoints_total_limit=3
+                """.split()
+
+            run_command(self._launch_args + resume_run_args)
+
+            # check checkpoint directories exist
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-8", "checkpoint-10", "checkpoint-12"},
+            )
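
Note for reviewers: the `checkpoints_total_limit` assertions above all exercise the same pruning pattern shared by the training scripts. Below is a minimal sketch of that logic, not the scripts' exact code; the helper name `prune_checkpoints` is illustrative (the scripts inline this right before saving a new checkpoint):

    import os
    import shutil

    def prune_checkpoints(output_dir: str, checkpoints_total_limit: int) -> None:
        # Existing checkpoint dirs, ordered by global step ("checkpoint-<step>").
        checkpoints = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
        checkpoints = sorted(checkpoints, key=lambda d: int(d.split("-")[1]))

        # Delete enough of the oldest checkpoints that, once the new one is
        # written, the total stays within the limit. On a resumed run this can
        # remove several directories at once, which is what the
        # "...removes_multiple_checkpoints" tests assert.
        if len(checkpoints) >= checkpoints_total_limit:
            num_to_remove = len(checkpoints) - checkpoints_total_limit + 1
            for removing_checkpoint in checkpoints[:num_to_remove]:
                shutil.rmtree(os.path.join(output_dir, removing_checkpoint))

For example, resuming from checkpoint-8 with `--checkpoints_total_limit=3` and {checkpoint-2, checkpoint-4, checkpoint-6, checkpoint-8} on disk removes checkpoint-2 and checkpoint-4 before checkpoint-10 is written, leaving {checkpoint-6, checkpoint-8, checkpoint-10} as asserted above.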