Mirror of https://github.com/huggingface/diffusers.git
Optimize log_validation in train_controlnet_flax (#3110)
Extract the pipeline creation out of log_validation so the pipeline is built and replicated once in main() and reused for every validation run.
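The gist of the change, as a minimal sketch rather than the training script itself (the checkpoint names and the run_validation helper below are illustrative stand-ins for the script's args and log_validation): the Stable Diffusion weights are loaded and replicated across devices once in main(), and each validation call only copies the replicated params and swaps in the ControlNet weights being trained.

# Minimal sketch of the refactored pattern, assuming the Flax ControlNet setup
# from examples/controlnet/train_controlnet_flax.py. Checkpoint names and the
# run_validation helper are illustrative, not part of the actual script.
import jax.numpy as jnp
from flax import jax_utils

from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline

# Stand-in for the ControlNet that the script trains.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32
)

# Previously this lived inside log_validation, so the full pipeline was
# re-loaded and re-replicated on every validation run. Now it happens once:
pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    safety_checker=None,
    revision="flax",
    dtype=jnp.float32,
)
pipeline_params = jax_utils.replicate(pipeline_params)


def run_validation(current_controlnet_params):
    # Per call: copy the already-replicated params and swap in the ControlNet
    # weights that change during training (state.params in the script, which
    # is already replicated across devices there).
    params = pipeline_params.copy()
    params["controlnet"] = current_controlnet_params
    # ...then call pipeline(prompt_ids=..., image=..., params=params, jit=True)
    # exactly as log_validation does.
    return params

With this split, the per-call cost drops to a dict copy, while the expensive from_pretrained load and the jax_utils.replicate device transfer happen once per training run.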
@@ -76,20 +76,11 @@ def image_grid(imgs, rows, cols):
     return grid


-def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_dtype):
-    logger.info("Running validation... ")
+def log_validation(pipeline, pipeline_params, controlnet_params, tokenizer, args, rng, weight_dtype):
+    logger.info("Running validation...")

-    pipeline, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
-        args.pretrained_model_name_or_path,
-        tokenizer=tokenizer,
-        controlnet=controlnet,
-        safety_checker=None,
-        dtype=weight_dtype,
-        revision=args.revision,
-        from_pt=args.from_pt,
-    )
-    params = jax_utils.replicate(params)
-    params["controlnet"] = controlnet_params
+    pipeline_params = pipeline_params.copy()
+    pipeline_params["controlnet"] = controlnet_params

     num_samples = jax.device_count()
     prng_seed = jax.random.split(rng, jax.device_count())
@@ -121,7 +112,7 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d
         images = pipeline(
             prompt_ids=prompt_ids,
             image=processed_image,
-            params=params,
+            params=pipeline_params,
             prng_seed=prng_seed,
             num_inference_steps=50,
             jit=True,
@@ -176,6 +167,7 @@ tags:
 - text-to-image
 - diffusers
 - controlnet
+- jax-diffusers-event
 inference: true
 ---
 """
@@ -800,6 +792,17 @@ def main():
         ]:
             controlnet_params[key] = unet_params[key]

+    pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        tokenizer=tokenizer,
+        controlnet=controlnet,
+        safety_checker=None,
+        dtype=weight_dtype,
+        revision=args.revision,
+        from_pt=args.from_pt,
+    )
+    pipeline_params = jax_utils.replicate(pipeline_params)
+
     # Optimization
     if args.scale_lr:
         args.learning_rate = args.learning_rate * total_train_batch_size
@@ -1073,7 +1076,7 @@ def main():
                 and global_step % args.validation_steps == 0
                 and jax.process_index() == 0
             ):
-                _ = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
+                _ = log_validation(pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype)

             if global_step % args.logging_steps == 0 and jax.process_index() == 0:
                 if args.report_to == "wandb":
@@ -1105,7 +1108,7 @@ def main():
         if args.validation_prompt is not None:
             if args.profile_validation:
                 jax.profiler.start_trace(args.output_dir)
-            image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
+            image_logs = log_validation(pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype)
             if args.profile_validation:
                 jax.profiler.stop_trace()
         else: