From b8bfef2ab94c423875158076aec481d8a65b7bfa Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 6 Mar 2023 19:11:45 +0100 Subject: [PATCH] make style --- .../textual_inversion_flax.py | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index e184eb0c0b..e988a25526 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -433,9 +433,15 @@ def main(): placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) # Load models and create wrapper for stable diffusion - text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder",revision=args.revision) - vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae",revision=args.revision) - unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet",revision=args.revision) + text_encoder = FlaxCLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + vae, vae_params = FlaxAutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision + ) + unet, unet_params = FlaxUNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) # Create sampling rng rng = jax.random.PRNGKey(args.seed) @@ -633,11 +639,13 @@ def main(): if global_step >= args.max_train_steps: break if global_step % args.save_steps == 0: - learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"]["embedding"][ - placeholder_token_id - ] + learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"][ + "embedding" + ][placeholder_token_id] learned_embeds_dict = {args.placeholder_token: learned_embeds} - jnp.save(os.path.join(args.output_dir, "learned_embeds-"+str(global_step)+".npy"), learned_embeds_dict) + jnp.save( + os.path.join(args.output_dir, "learned_embeds-" + str(global_step) + ".npy"), learned_embeds_dict + ) train_metric = jax_utils.unreplicate(train_metric)