diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 24b32e7f43..b25f932540 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -1076,7 +1076,9 @@ def main(): and global_step % args.validation_steps == 0 and jax.process_index() == 0 ): - _ = log_validation(pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype) + _ = log_validation( + pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype + ) if global_step % args.logging_steps == 0 and jax.process_index() == 0: if args.report_to == "wandb": @@ -1108,7 +1110,9 @@ def main(): if args.validation_prompt is not None: if args.profile_validation: jax.profiler.start_trace(args.output_dir) - image_logs = log_validation(pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype) + image_logs = log_validation( + pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype + ) if args.profile_validation: jax.profiler.stop_trace() else: