1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Easy] fix: save_model_card utility of the DreamBooth SDXL LoRA script (#7258)

* fix: save_model_card utility.

* fix a little more to make it more lenient.

* remove lower()
This commit is contained in:
Sayak Paul
2024-03-08 15:22:23 +05:30
committed by GitHub
parent d9a3b69806
commit 9d9744075e

View File

@@ -114,7 +114,7 @@ def save_model_card(
)
model_description = f"""
# {'SDXL' if 'playgroundai' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
<Gallery />
@@ -139,7 +139,7 @@ Weights for this model are available in Safetensors format.
[Download]({repo_id}/tree/main) them in the Files & versions tab.
"""
if "playgroundai" in args.pretrained_model_name_or_path:
if "playground" in base_model:
model_description += """\n
## License
@@ -148,7 +148,7 @@ Please adhere to the licensing terms as described [here](https://huggingface.co/
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="openrail++" if "playgroundai" not in base_model else "playground-v2dot5-community",
license="openrail++" if "playground" not in base_model else "playground-v2dot5-community",
base_model=base_model,
prompt=instance_prompt,
model_description=model_description,
@@ -162,7 +162,7 @@ Please adhere to the licensing terms as described [here](https://huggingface.co/
"lora" if not use_dora else "dora",
"template:sd-lora",
]
if "playgroundai" in base_model:
if "playground" in base_model:
tags.extend(["playground", "playground-diffusers"])
else:
tags.extend(["stable-diffusion-xl", "stable-diffusion-xl-diffusers"])
@@ -206,7 +206,7 @@ def log_validation(
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
inference_ctx = (
contextlib.nullcontext() if "playgroundai" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
)
with inference_ctx:
@@ -1509,7 +1509,7 @@ def main(args):
if accelerator.is_main_process:
tracker_name = (
"dreambooth-lora-sd-xl"
if "playgroundai" not in args.pretrained_model_name_or_path
if "playground" not in args.pretrained_model_name_or_path
else "dreambooth-lora-playground"
)
accelerator.init_trackers(tracker_name, config=vars(args))