1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

[WIP] Check UNet shapes in StableDiffusionInpaintPipeline __init__ (#2853)

Add warning in __init__ if user loads a checkpoint with pipeline.unet.config.in_channels other than 9.
This commit is contained in:
dg845
2023-03-28 08:53:52 -07:00
committed by GitHub
parent 25d927aa51
commit 4d0f412d0d

View File

@@ -243,6 +243,14 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
if unet.config.in_channels != 9:
logger.warning(
f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default,"
f" {self.__class__} assumes that `pipeline.unet` has 9 input channels: 4 for `num_channels_latents`,"
" 1 for `num_channels_mask`, and 4 for `num_channels_masked_image`. If you did not intend to modify"
" this behavior, please check whether you have loaded the right checkpoint."
)
self.register_modules(
vae=vae,