mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Let's make sure that dreambooth always uploads to the Hub (#3272)
* Update Dreambooth README * Adapt all docs as well * automatically write model card * fix * make style
This commit is contained in:
committed by
Daniel Gu
parent
ffe6e92577
commit
10d856a94c
@@ -98,7 +98,8 @@ accelerate launch train_dreambooth.py \
|
||||
--learning_rate=5e-6 \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=400
|
||||
--max_train_steps=400 \
|
||||
--push_to_hub
|
||||
```
|
||||
</pt>
|
||||
<jax>
|
||||
@@ -161,7 +162,8 @@ accelerate launch train_dreambooth.py \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--num_class_images=200 \
|
||||
--max_train_steps=800
|
||||
--max_train_steps=800 \
|
||||
--push_to_hub
|
||||
```
|
||||
</pt>
|
||||
<jax>
|
||||
@@ -225,7 +227,8 @@ accelerate launch train_dreambooth.py \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--num_class_images=200 \
|
||||
--max_train_steps=800
|
||||
--max_train_steps=800 \
|
||||
--push_to_hub
|
||||
```
|
||||
</pt>
|
||||
<jax>
|
||||
@@ -387,7 +390,8 @@ accelerate launch train_dreambooth.py \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--num_class_images=200 \
|
||||
--max_train_steps=800
|
||||
--max_train_steps=800 \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
### 12GB GPU
|
||||
@@ -418,7 +422,8 @@ accelerate launch train_dreambooth.py \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--num_class_images=200 \
|
||||
--max_train_steps=800
|
||||
--max_train_steps=800 \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
### 8 GB GPU
|
||||
@@ -464,7 +469,8 @@ accelerate launch train_dreambooth.py \
|
||||
--lr_warmup_steps=0 \
|
||||
--num_class_images=200 \
|
||||
--max_train_steps=800 \
|
||||
--mixed_precision=fp16
|
||||
--mixed_precision=fp16 \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
@@ -80,7 +80,8 @@ accelerate launch train_dreambooth.py \
|
||||
--learning_rate=5e-6 \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=400
|
||||
--max_train_steps=400 \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
### Training with prior-preservation loss
|
||||
@@ -109,7 +110,8 @@ accelerate launch train_dreambooth.py \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--num_class_images=200 \
|
||||
--max_train_steps=800
|
||||
--max_train_steps=800 \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
|
||||
@@ -141,7 +143,8 @@ accelerate launch train_dreambooth.py \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--num_class_images=200 \
|
||||
--max_train_steps=800
|
||||
--max_train_steps=800 \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
|
||||
@@ -176,7 +179,8 @@ accelerate launch train_dreambooth.py \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--num_class_images=200 \
|
||||
--max_train_steps=800
|
||||
--max_train_steps=800 \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
|
||||
@@ -218,7 +222,8 @@ accelerate launch --mixed_precision="fp16" train_dreambooth.py \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--num_class_images=200 \
|
||||
--max_train_steps=800
|
||||
--max_train_steps=800 \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
### Fine-tune text encoder with the UNet.
|
||||
@@ -251,7 +256,8 @@ accelerate launch train_dreambooth.py \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--num_class_images=200 \
|
||||
--max_train_steps=800
|
||||
--max_train_steps=800 \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
### Using DreamBooth for pipelines other than Stable Diffusion
|
||||
|
||||
@@ -61,6 +61,39 @@ check_min_version("0.17.0.dev0")
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def save_model_card(repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None):
|
||||
img_str = ""
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
img_str += f"\n"
|
||||
|
||||
yaml = f"""
|
||||
---
|
||||
license: creativeml-openrail-m
|
||||
base_model: {base_model}
|
||||
instance_prompt: {prompt}
|
||||
tags:
|
||||
- stable-diffusion
|
||||
- stable-diffusion-diffusers
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- dreambooth
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
model_card = f"""
|
||||
# DreamBooth - {repo_id}
|
||||
|
||||
This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
|
||||
You can find some example images in the following. \n
|
||||
{img_str}
|
||||
|
||||
DreamBooth for the text encoder was enabled: {train_text_encoder}.
|
||||
"""
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
|
||||
|
||||
def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
@@ -104,6 +137,8 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
|
||||
text_encoder_config = PretrainedConfig.from_pretrained(
|
||||
@@ -997,13 +1032,16 @@ def main(args):
|
||||
global_step += 1
|
||||
|
||||
if accelerator.is_main_process:
|
||||
images = []
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
|
||||
log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
|
||||
images = log_validation(
|
||||
text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch
|
||||
)
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
@@ -1024,6 +1062,14 @@ def main(args):
|
||||
pipeline.save_pretrained(args.output_dir)
|
||||
|
||||
if args.push_to_hub:
|
||||
save_model_card(
|
||||
repo_id,
|
||||
images=images,
|
||||
base_model=args.pretrained_model_name_or_path,
|
||||
train_text_encoder=args.train_text_encoder,
|
||||
prompt=args.instance_prompt,
|
||||
repo_folder=args.output_dir,
|
||||
)
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=args.output_dir,
|
||||
|
||||
Reference in New Issue
Block a user