mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
[WIP] Check UNet shapes in StableDiffusionInpaintPipeline __init__ (#2853)
Add warning in __init__ if user loads a checkpoint with pipeline.unet.config.in_channels other than 9.
This commit is contained in:
@@ -243,6 +243,14 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
|
||||
if unet.config.in_channels != 9:
|
||||
logger.warning(
|
||||
f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default,"
|
||||
f" {self.__class__} assumes that `pipeline.unet` has 9 input channels: 4 for `num_channels_latents`,"
|
||||
" 1 for `num_channels_mask`, and 4 for `num_channels_masked_image`. If you did not intend to modify"
|
||||
" this behavior, please check whether you have loaded the right checkpoint."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
|
||||
Reference in New Issue
Block a user