mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[ControlNet SDXL training] fixes in the training script (#4223)
* fix: #4206 * add: sdxl controlnet training smoketest. * remove unnecessary token inits. * add: licensing to model card. * include SDXL licensing in the model card and make public visibility default * debugging * debugging * disable local file download. * fix: training test. * fix: ckpt prefix.
This commit is contained in:
@@ -124,7 +124,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
|
||||
for _ in range(args.num_validation_images):
|
||||
with torch.autocast("cuda"):
|
||||
image = pipeline(
|
||||
validation_prompt, validation_image, num_inference_steps=20, generator=generator
|
||||
prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
|
||||
).images[0]
|
||||
images.append(image)
|
||||
|
||||
@@ -178,7 +178,7 @@ def import_model_class_from_model_name_or_path(
|
||||
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
|
||||
):
|
||||
text_encoder_config = PretrainedConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
|
||||
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
|
||||
)
|
||||
model_class = text_encoder_config.architectures[0]
|
||||
|
||||
@@ -226,6 +226,12 @@ inference: true
|
||||
|
||||
These are controlnet weights trained on {base_model} with new type of conditioning.
|
||||
{img_str}
|
||||
"""
|
||||
model_card += """
|
||||
|
||||
## License
|
||||
|
||||
[SDXL 0.9 Research License](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9/blob/main/LICENSE.md)
|
||||
"""
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
@@ -798,10 +804,7 @@ def main(args):
|
||||
|
||||
if args.push_to_hub:
|
||||
repo_id = create_repo(
|
||||
repo_id=args.hub_model_id or Path(args.output_dir).name,
|
||||
exist_ok=True,
|
||||
token=args.hub_token,
|
||||
private=True,
|
||||
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
||||
).repo_id
|
||||
|
||||
# Load the tokenizers
|
||||
@@ -839,7 +842,7 @@ def main(args):
|
||||
revision=args.revision,
|
||||
)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, use_auth_token=True
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
||||
)
|
||||
|
||||
if args.controlnet_model_name_or_path:
|
||||
|
||||
@@ -1296,6 +1296,25 @@ class ExamplesTestsAccelerate(unittest.TestCase):
|
||||
{"checkpoint-8", "checkpoint-10", "checkpoint-12"},
|
||||
)
|
||||
|
||||
def test_controlnet_sdxl(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/controlnet/train_controlnet_sdxl.py
|
||||
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
|
||||
--dataset_name=hf-internal-testing/fill10
|
||||
--output_dir={tmpdir}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl
|
||||
--max_train_steps=9
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.bin")))
|
||||
|
||||
def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
|
||||
@@ -751,7 +751,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
||||
|
||||
sample = sample + controlnet_cond
|
||||
|
||||
# 3. down
|
||||
|
||||
Reference in New Issue
Block a user