diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index 6c14e8ca10..1ac3452828 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -106,7 +106,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler image_logs = [] for validation_prompt, validation_image in zip(validation_prompts, validation_images): - validation_image = Image.open(validation_image) + validation_image = Image.open(validation_image).convert('RGB') images = [] diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index f409a53966..dab6864b07 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -110,7 +110,7 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d prompt_ids = pipeline.prepare_text_inputs(prompts) prompt_ids = shard(prompt_ids) - validation_image = Image.open(validation_image) + validation_image = Image.open(validation_image).convert('RGB') processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image]) processed_image = shard(processed_image) images = pipeline(