mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
[Init] Make sure shape mismatches are caught early (#2847)
Improve init
This commit is contained in:
committed by
GitHub
parent
81125d8499
commit
42d950174f
@@ -579,10 +579,17 @@ class ModelMixin(torch.nn.Module):
|
||||
" those weights or else make sure your checkpoint file is correct."
|
||||
)
|
||||
|
||||
empty_state_dict = model.state_dict()
|
||||
for param_name, param in state_dict.items():
|
||||
accepts_dtype = "dtype" in set(
|
||||
inspect.signature(set_module_tensor_to_device).parameters.keys()
|
||||
)
|
||||
|
||||
if empty_state_dict[param_name].shape != param.shape:
|
||||
raise ValueError(
|
||||
f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
||||
)
|
||||
|
||||
if accepts_dtype:
|
||||
set_module_tensor_to_device(
|
||||
model, param_name, param_device, value=param, dtype=torch_dtype
|
||||
|
||||
@@ -100,6 +100,30 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = True
|
||||
|
||||
def test_weight_overwrite(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
|
||||
UNet2DConditionModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch",
|
||||
subfolder="unet",
|
||||
cache_dir=tmpdirname,
|
||||
in_channels=9,
|
||||
)
|
||||
|
||||
# make sure that error message states what keys are missing
|
||||
assert "Cannot load" in str(error_context.exception)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model = UNet2DConditionModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch",
|
||||
subfolder="unet",
|
||||
cache_dir=tmpdirname,
|
||||
in_channels=9,
|
||||
low_cpu_mem_usage=False,
|
||||
ignore_mismatched_sizes=True,
|
||||
)
|
||||
|
||||
assert model.config.in_channels == 9
|
||||
|
||||
|
||||
class ModelTesterMixin:
|
||||
def test_from_save_pretrained(self):
|
||||
|
||||
Reference in New Issue
Block a user