mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Examples/DreamBooth] refactor save_model_card utility in dreambooth examples (#3543)
refactor save_model_card utility in dreambooth examples.
This commit is contained in:
@@ -46,6 +46,7 @@ from diffusers import (
|
||||
DDPMScheduler,
|
||||
DiffusionPipeline,
|
||||
DPMSolverMultistepScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
@@ -62,7 +63,15 @@ check_min_version("0.17.0.dev0")
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def save_model_card(repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None):
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
images=None,
|
||||
base_model=str,
|
||||
train_text_encoder=False,
|
||||
prompt=str,
|
||||
repo_folder=None,
|
||||
pipeline: DiffusionPipeline = None,
|
||||
):
|
||||
img_str = ""
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
@@ -74,8 +83,8 @@ license: creativeml-openrail-m
|
||||
base_model: {base_model}
|
||||
instance_prompt: {prompt}
|
||||
tags:
|
||||
- stable-diffusion
|
||||
- stable-diffusion-diffusers
|
||||
- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
|
||||
- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- dreambooth
|
||||
@@ -1297,6 +1306,7 @@ def main(args):
|
||||
train_text_encoder=args.train_text_encoder,
|
||||
prompt=args.instance_prompt,
|
||||
repo_folder=args.output_dir,
|
||||
pipeline=pipeline,
|
||||
)
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
|
||||
@@ -68,7 +68,15 @@ check_min_version("0.17.0.dev0")
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def save_model_card(repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None):
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
images=None,
|
||||
base_model=str,
|
||||
train_text_encoder=False,
|
||||
prompt=str,
|
||||
repo_folder=None,
|
||||
pipeline: DiffusionPipeline = None,
|
||||
):
|
||||
img_str = ""
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
@@ -80,8 +88,8 @@ license: creativeml-openrail-m
|
||||
base_model: {base_model}
|
||||
instance_prompt: {prompt}
|
||||
tags:
|
||||
- stable-diffusion
|
||||
- stable-diffusion-diffusers
|
||||
- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
|
||||
- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- lora
|
||||
@@ -844,7 +852,7 @@ def main(args):
|
||||
hidden_size=module.out_features, cross_attention_dim=None
|
||||
)
|
||||
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
|
||||
temp_pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
temp_pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path, text_encoder=text_encoder
|
||||
)
|
||||
temp_pipeline._modify_text_encoder(text_lora_attn_procs)
|
||||
@@ -1332,6 +1340,7 @@ def main(args):
|
||||
train_text_encoder=args.train_text_encoder,
|
||||
prompt=args.instance_prompt,
|
||||
repo_folder=args.output_dir,
|
||||
pipeline=pipeline,
|
||||
)
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
|
||||
Reference in New Issue
Block a user