diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 77ef350c51..ee960397d3 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -16,8 +16,9 @@ import PIL
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
+from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from huggingface_hub import HfFolder, Repository, whoami

 # TODO: remove and import from diffusers.utils when the new version of diffusers is released
@@ -25,7 +26,7 @@ from packaging import version
 from PIL import Image
 from torchvision import transforms
 from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


 if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
@@ -65,6 +66,12 @@ def parse_args():
         default=500,
         help="Save learned_embeds.bin every X updates steps.",
     )
+    parser.add_argument(
+        "--only_save_embeds",
+        action="store_true",
+        default=False,
+        help="Save only the embeddings for the new concept.",
+    )
     parser.add_argument(
         "--pretrained_model_name_or_path",
         type=str,
@@ -596,16 +603,23 @@ def main():

     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
-        pipeline = StableDiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path,
-            text_encoder=accelerator.unwrap_model(text_encoder),
-            tokenizer=tokenizer,
-            vae=vae,
-            unet=unet,
-            revision=args.revision,
-        )
-        pipeline.save_pretrained(args.output_dir)
-        # Also save the newly trained embeddings
+        if args.push_to_hub and args.only_save_embeds:
+            logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
+            save_full_model = True
+        else:
+            save_full_model = not args.only_save_embeds
+        if save_full_model:
+            pipeline = StableDiffusionPipeline(
+                text_encoder=accelerator.unwrap_model(text_encoder),
+                vae=vae,
+                unet=unet,
+                tokenizer=tokenizer,
+                scheduler=PNDMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler"),
+                safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
+                feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+            )
+            pipeline.save_pretrained(args.output_dir)
+        # Save the newly trained embeddings
         save_path = os.path.join(args.output_dir, "learned_embeds.bin")
         save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
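
With `--only_save_embeds`, the output directory contains only `learned_embeds.bin`, a dict mapping the placeholder token to its trained embedding tensor (the format written by `save_progress` in this script). A minimal sketch of loading that file back into a fresh pipeline; the base model id and file path here are assumptions for illustration, not part of the patch:

```python
# Loading sketch, not part of this patch. Assumes the base model id and the
# learned_embeds.bin path; the file format matches save_progress in this
# script: {placeholder_token: embedding_tensor}.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")  # assumed base model
learned_embeds = torch.load("learned_embeds.bin", map_location="cpu")
placeholder_token, embedding = next(iter(learned_embeds.items()))

# Register the placeholder token and grow the text encoder's embedding matrix.
pipe.tokenizer.add_tokens(placeholder_token)
pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))

# Copy the trained vector into the new token's embedding row.
token_id = pipe.tokenizer.convert_tokens_to_ids(placeholder_token)
with torch.no_grad():
    pipe.text_encoder.get_input_embeddings().weight[token_id] = embedding
```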