1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Lora] Model card (#2032)

* [Lora] up lora training

* finish

* finish

* finish model card
This commit is contained in:
Patrick von Platen
2023-01-19 09:44:02 +01:00
committed by GitHub
parent 3c07840b1b
commit 007c914c70

View File

@@ -58,6 +58,34 @@ check_min_version("0.12.0.dev0")
logger = get_logger(__name__)
def save_model_card(repo_name, images=None, base_model=str, prompt=str, repo_folder=None):
img_str = ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"
yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
inference: true
---
"""
model_card = f"""
# LoRA DreamBooth - {repo_name}
These are LoRA adaption weights for {repo_name}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
{img_str}
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
@@ -913,34 +941,42 @@ def main(args):
unet = unet.to(torch.float32)
unet.save_attn_procs(args.output_dir)
# Final inference
# Load previous pipeline
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(accelerator.device)
# load attention processors
pipeline.unet.load_attn_procs(args.output_dir)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
prompt = args.num_validation_images * [args.validation_prompt]
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
for tracker in accelerator.trackers:
if tracker.name == "wandb":
tracker.log(
{
"test": [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
for i, image in enumerate(images)
]
}
)
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
# Final inference
# Load previous pipeline
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(accelerator.device)
# load attention processors
pipeline.unet.load_attn_procs(args.output_dir)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
prompt = args.num_validation_images * [args.validation_prompt]
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
for tracker in accelerator.trackers:
if tracker.name == "wandb":
tracker.log(
{
"test": [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
]
}
save_model_card(
repo_name,
images=images,
base_model=args.pretrained_model_name_or_path,
prompt=args.instance_prompt,
repo_folder=args.output_dir,
)
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
accelerator.end_training()