diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py index 078a66e4ac..6606af4f17 100644 --- a/examples/dreambooth/train_dreambooth_flax.py +++ b/examples/dreambooth/train_dreambooth_flax.py @@ -327,22 +327,6 @@ def main(): if args.seed is not None: set_seed(args.seed) - if jax.process_index() == 0: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - repo = Repository(args.output_dir, clone_from=repo_name) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - rng = jax.random.PRNGKey(args.seed) if args.with_prior_preservation: