diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index ad43ee7aee..158d03185a 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -46,6 +46,7 @@ from diffusers import ( DDPMScheduler, DiffusionPipeline, DPMSolverMultistepScheduler, + StableDiffusionPipeline, UNet2DConditionModel, ) from diffusers.optimization import get_scheduler @@ -62,7 +63,15 @@ check_min_version("0.17.0.dev0") logger = get_logger(__name__) -def save_model_card(repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None): +def save_model_card( + repo_id: str, + images=None, + base_model=str, + train_text_encoder=False, + prompt=str, + repo_folder=None, + pipeline: DiffusionPipeline = None, +): img_str = "" for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) @@ -74,8 +83,8 @@ license: creativeml-openrail-m base_model: {base_model} instance_prompt: {prompt} tags: -- stable-diffusion -- stable-diffusion-diffusers +- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'} +- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'} - text-to-image - diffusers - dreambooth @@ -1297,6 +1306,7 @@ def main(args): train_text_encoder=args.train_text_encoder, prompt=args.instance_prompt, repo_folder=args.output_dir, + pipeline=pipeline, ) upload_folder( repo_id=repo_id, diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index e640542e36..4ff759dcd6 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -68,7 +68,15 @@ check_min_version("0.17.0.dev0") logger = get_logger(__name__) -def save_model_card(repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None): +def save_model_card( + repo_id: str, + images=None, + base_model=str, + train_text_encoder=False, + prompt=str, + repo_folder=None, + pipeline: DiffusionPipeline = None, +): img_str = "" for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) @@ -80,8 +88,8 @@ license: creativeml-openrail-m base_model: {base_model} instance_prompt: {prompt} tags: -- stable-diffusion -- stable-diffusion-diffusers +- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'} +- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'} - text-to-image - diffusers - lora @@ -844,7 +852,7 @@ def main(args): hidden_size=module.out_features, cross_attention_dim=None ) text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) - temp_pipeline = StableDiffusionPipeline.from_pretrained( + temp_pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, text_encoder=text_encoder ) temp_pipeline._modify_text_encoder(text_lora_attn_procs) @@ -1332,6 +1340,7 @@ def main(args): train_text_encoder=args.train_text_encoder, prompt=args.instance_prompt, repo_folder=args.output_dir, + pipeline=pipeline, ) upload_folder( repo_id=repo_id,