From 4d0f412d0d5a55d7d653cecfe6cf770a7f1af277 Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Tue, 28 Mar 2023 08:53:52 -0700 Subject: [PATCH] [WIP] Check UNet shapes in StableDiffusionInpaintPipeline __init__ (#2853) Add warning in __init__ if user loads a checkpoint with pipeline.unet.config.in_channels other than 9. --- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 199325236c..a934f639a5 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -243,6 +243,14 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) + # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 + if unet.config.in_channels != 9: + logger.warning( + f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default," + f" {self.__class__} assumes that `pipeline.unet` has 9 input channels: 4 for `num_channels_latents`," + " 1 for `num_channels_mask`, and 4 for `num_channels_masked_image`. If you did not intend to modify" + " this behavior, please check whether you have loaded the right checkpoint." + ) self.register_modules( vae=vae,