1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[ControlNet SDXL training] fixes in the training script (#4223)

* fix: #4206

* add: sdxl controlnet training smoketest.

* remove unnecessary token inits.

* add: licensing to model card.

* include SDXL licensing in the model card and make public visibility default

* debugging

* debugging

* disable local file download.

* fix: training test.

* fix: ckpt prefix.
This commit is contained in:
Sayak Paul
2023-07-25 05:31:48 +05:30
committed by GitHub
parent 95b7de88fd
commit fed12376c5
3 changed files with 29 additions and 8 deletions

View File

@@ -124,7 +124,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
for _ in range(args.num_validation_images):
with torch.autocast("cuda"):
image = pipeline(
validation_prompt, validation_image, num_inference_steps=20, generator=generator
prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
).images[0]
images.append(image)
@@ -178,7 +178,7 @@ def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
)
model_class = text_encoder_config.architectures[0]
@@ -226,6 +226,12 @@ inference: true
These are controlnet weights trained on {base_model} with new type of conditioning.
{img_str}
"""
model_card += """
## License
[SDXL 0.9 Research License](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9/blob/main/LICENSE.md)
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
@@ -798,10 +804,7 @@ def main(args):
if args.push_to_hub:
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name,
exist_ok=True,
token=args.hub_token,
private=True,
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
).repo_id
# Load the tokenizers
@@ -839,7 +842,7 @@ def main(args):
revision=args.revision,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, use_auth_token=True
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
if args.controlnet_model_name_or_path:

View File

@@ -1296,6 +1296,25 @@ class ExamplesTestsAccelerate(unittest.TestCase):
{"checkpoint-8", "checkpoint-10", "checkpoint-12"},
)
def test_controlnet_sdxl(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/controlnet/train_controlnet_sdxl.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
--dataset_name=hf-internal-testing/fill10
--output_dir={tmpdir}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl
--max_train_steps=9
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.bin")))
def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""

View File

@@ -751,7 +751,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
sample = self.conv_in(sample)
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
sample = sample + controlnet_cond
# 3. down