mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[Lora] Model card (#2032)
* [Lora] up lora training
* finish
* finish
* finish model card
committed by GitHub
parent 3c07840b1b
commit 007c914c70
@@ -58,6 +58,34 @@ check_min_version("0.12.0.dev0")
 logger = get_logger(__name__)
 
 
+def save_model_card(repo_name, images=None, base_model=None, prompt=None, repo_folder=None):
+    img_str = ""
+    for i, image in enumerate(images):
+        image.save(os.path.join(repo_folder, f"image_{i}.png"))
+        img_str += f"![img_{i}](./image_{i}.png)\n"
+
+    yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+inference: true
+---
+    """
+    model_card = f"""
+# LoRA DreamBooth - {repo_name}
+
+These are LoRA adaptation weights for {repo_name}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images below. \n
+{img_str}
+"""
+    with open(os.path.join(repo_folder, "README.md"), "w") as f:
+        f.write(yaml + model_card)
+
+
 def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
     text_encoder_config = PretrainedConfig.from_pretrained(
         pretrained_model_name_or_path,
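For reference, a minimal sketch of how the new save_model_card helper could be exercised on its own; the repo name, base model id, prompt, and output folder below are placeholder values for illustration, not taken from the diff:

import os
from PIL import Image

# Stand-ins for images generated during validation (placeholder values).
os.makedirs("./output", exist_ok=True)
samples = [Image.new("RGB", (512, 512), "gray") for _ in range(2)]

save_model_card(
    "my-lora-dreambooth",                          # hypothetical repo name
    images=samples,
    base_model="runwayml/stable-diffusion-v1-5",   # hypothetical base checkpoint
    prompt="a photo of sks dog",                   # hypothetical instance prompt
    repo_folder="./output",
)
# Writes ./output/README.md (YAML front matter followed by the
# "LoRA DreamBooth" section) plus ./output/image_0.png, ./output/image_1.png.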
@@ -913,34 +941,42 @@ def main(args):
         unet = unet.to(torch.float32)
         unet.save_attn_procs(args.output_dir)
 
-    # Final inference
-    # Load previous pipeline
-    pipeline = DiffusionPipeline.from_pretrained(
-        args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
-    )
-    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
-    pipeline = pipeline.to(accelerator.device)
-
-    # load attention processors
-    pipeline.unet.load_attn_procs(args.output_dir)
-
-    # run inference
-    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
-    prompt = args.num_validation_images * [args.validation_prompt]
-    images = pipeline(prompt, num_inference_steps=25, generator=generator).images
-
-    for tracker in accelerator.trackers:
-        if tracker.name == "wandb":
-            tracker.log(
-                {
-                    "test": [
-                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                        for i, image in enumerate(images)
-                    ]
-                }
-            )
-
-    if args.push_to_hub:
-        repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+        # Final inference
+        # Load previous pipeline
+        pipeline = DiffusionPipeline.from_pretrained(
+            args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
+        )
+        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+        pipeline = pipeline.to(accelerator.device)
+
+        # load attention processors
+        pipeline.unet.load_attn_procs(args.output_dir)
+
+        # run inference
+        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+        prompt = args.num_validation_images * [args.validation_prompt]
+        images = pipeline(prompt, num_inference_steps=25, generator=generator).images
+
+        for tracker in accelerator.trackers:
+            if tracker.name == "wandb":
+                tracker.log(
+                    {
+                        "test": [
+                            wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+                        ]
+                    }
+                )
+
+        if args.push_to_hub:
+            save_model_card(
+                repo_name,
+                images=images,
+                base_model=args.pretrained_model_name_or_path,
+                prompt=args.instance_prompt,
+                repo_folder=args.output_dir,
+            )
+            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
 
     accelerator.end_training()
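The final-inference block above also doubles as a recipe for running the trained adapter outside the training script. A minimal standalone sketch, assuming diffusers >= 0.12 (the version pinned by check_min_version above), with placeholder values for the base checkpoint id, the weights directory, and the prompt:

import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # hypothetical base checkpoint
    torch_dtype=torch.float16,
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to("cuda")

# Load the LoRA attention processors written by unet.save_attn_procs(...).
pipeline.unet.load_attn_procs("./output")  # placeholder weights directory

generator = torch.Generator(device="cuda").manual_seed(0)
image = pipeline("a photo of sks dog", num_inference_steps=25, generator=generator).images[0]
image.save("sample.png")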