1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Optimize log_validation in train_controlnet_flax (#3110)

extract pipeline from log_validation
This commit is contained in:
Cristian Garcia
2023-04-18 07:03:00 -05:00
committed by GitHub
parent cd8b7507c2
commit 8ecdd3ef65

View File

@@ -76,20 +76,11 @@ def image_grid(imgs, rows, cols):
return grid
def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_dtype):
logger.info("Running validation... ")
def log_validation(pipeline, pipeline_params, controlnet_params, tokenizer, args, rng, weight_dtype):
logger.info("Running validation...")
pipeline, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
args.pretrained_model_name_or_path,
tokenizer=tokenizer,
controlnet=controlnet,
safety_checker=None,
dtype=weight_dtype,
revision=args.revision,
from_pt=args.from_pt,
)
params = jax_utils.replicate(params)
params["controlnet"] = controlnet_params
pipeline_params = pipeline_params.copy()
pipeline_params["controlnet"] = controlnet_params
num_samples = jax.device_count()
prng_seed = jax.random.split(rng, jax.device_count())
@@ -121,7 +112,7 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d
images = pipeline(
prompt_ids=prompt_ids,
image=processed_image,
params=params,
params=pipeline_params,
prng_seed=prng_seed,
num_inference_steps=50,
jit=True,
@@ -176,6 +167,7 @@ tags:
- text-to-image
- diffusers
- controlnet
- jax-diffusers-event
inference: true
---
"""
@@ -800,6 +792,17 @@ def main():
]:
controlnet_params[key] = unet_params[key]
pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
args.pretrained_model_name_or_path,
tokenizer=tokenizer,
controlnet=controlnet,
safety_checker=None,
dtype=weight_dtype,
revision=args.revision,
from_pt=args.from_pt,
)
pipeline_params = jax_utils.replicate(pipeline_params)
# Optimization
if args.scale_lr:
args.learning_rate = args.learning_rate * total_train_batch_size
@@ -1073,7 +1076,7 @@ def main():
and global_step % args.validation_steps == 0
and jax.process_index() == 0
):
_ = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
_ = log_validation(pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype)
if global_step % args.logging_steps == 0 and jax.process_index() == 0:
if args.report_to == "wandb":
@@ -1105,7 +1108,7 @@ def main():
if args.validation_prompt is not None:
if args.profile_validation:
jax.profiler.start_trace(args.output_dir)
image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
image_logs = log_validation(pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype)
if args.profile_validation:
jax.profiler.stop_trace()
else: