mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Easy] fix: save_model_card utility of the DreamBooth SDXL LoRA script (#7258)
* fix: save_model_card utility. * fix a little more to make it more lenient. * remove lower()
This commit is contained in:
@@ -114,7 +114,7 @@ def save_model_card(
|
||||
)
|
||||
|
||||
model_description = f"""
|
||||
# {'SDXL' if 'playgroundai' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
|
||||
# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
|
||||
|
||||
<Gallery />
|
||||
|
||||
@@ -139,7 +139,7 @@ Weights for this model are available in Safetensors format.
|
||||
[Download]({repo_id}/tree/main) them in the Files & versions tab.
|
||||
|
||||
"""
|
||||
if "playgroundai" in args.pretrained_model_name_or_path:
|
||||
if "playground" in base_model:
|
||||
model_description += """\n
|
||||
## License
|
||||
|
||||
@@ -148,7 +148,7 @@ Please adhere to the licensing terms as described [here](https://huggingface.co/
|
||||
model_card = load_or_create_model_card(
|
||||
repo_id_or_path=repo_id,
|
||||
from_training=True,
|
||||
license="openrail++" if "playgroundai" not in base_model else "playground-v2dot5-community",
|
||||
license="openrail++" if "playground" not in base_model else "playground-v2dot5-community",
|
||||
base_model=base_model,
|
||||
prompt=instance_prompt,
|
||||
model_description=model_description,
|
||||
@@ -162,7 +162,7 @@ Please adhere to the licensing terms as described [here](https://huggingface.co/
|
||||
"lora" if not use_dora else "dora",
|
||||
"template:sd-lora",
|
||||
]
|
||||
if "playgroundai" in base_model:
|
||||
if "playground" in base_model:
|
||||
tags.extend(["playground", "playground-diffusers"])
|
||||
else:
|
||||
tags.extend(["stable-diffusion-xl", "stable-diffusion-xl-diffusers"])
|
||||
@@ -206,7 +206,7 @@ def log_validation(
|
||||
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
|
||||
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
|
||||
inference_ctx = (
|
||||
contextlib.nullcontext() if "playgroundai" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
|
||||
contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
|
||||
)
|
||||
|
||||
with inference_ctx:
|
||||
@@ -1509,7 +1509,7 @@ def main(args):
|
||||
if accelerator.is_main_process:
|
||||
tracker_name = (
|
||||
"dreambooth-lora-sd-xl"
|
||||
if "playgroundai" not in args.pretrained_model_name_or_path
|
||||
if "playground" not in args.pretrained_model_name_or_path
|
||||
else "dreambooth-lora-playground"
|
||||
)
|
||||
accelerator.init_trackers(tracker_name, config=vars(args))
|
||||
|
||||
Reference in New Issue
Block a user