From 42d950174f5da973d3d35e55d3e1e49edf87a35b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 28 Mar 2023 10:08:28 +0200 Subject: [PATCH] [Init] Make sure shape mismatches are caught early (#2847) Improve init --- src/diffusers/models/modeling_utils.py | 7 +++++++ tests/test_modeling_common.py | 24 ++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index aa4e2b0ea4..5a5d233fbb 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -579,10 +579,17 @@ class ModelMixin(torch.nn.Module): " those weights or else make sure your checkpoint file is correct." ) + empty_state_dict = model.state_dict() for param_name, param in state_dict.items(): accepts_dtype = "dtype" in set( inspect.signature(set_module_tensor_to_device).parameters.keys() ) + + if empty_state_dict[param_name].shape != param.shape: + raise ValueError( + f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." + ) + if accepts_dtype: set_module_tensor_to_device( model, param_name, param_device, value=param, dtype=torch_dtype diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 1c45ce11b8..40aba3b249 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -100,6 +100,30 @@ class ModelUtilsTest(unittest.TestCase): diffusers.utils.import_utils._safetensors_available = True + def test_weight_overwrite(self): + with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context: + UNet2DConditionModel.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", + subfolder="unet", + cache_dir=tmpdirname, + in_channels=9, + ) + + # make sure that error message states what keys are missing + assert "Cannot load" in str(error_context.exception) + + with tempfile.TemporaryDirectory() as tmpdirname: + model = UNet2DConditionModel.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", + subfolder="unet", + cache_dir=tmpdirname, + in_channels=9, + low_cpu_mem_usage=False, + ignore_mismatched_sizes=True, + ) + + assert model.config.in_channels == 9 + class ModelTesterMixin: def test_from_save_pretrained(self):