diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index d68cef7040..6aef806075 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -208,6 +208,7 @@ class ConfigMixin: def extract_init_dict(cls, config_dict, **kwargs): expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys()) expected_keys.remove("self") + expected_keys.remove("kwargs") init_dict = {} for key in expected_keys: if key in kwargs: diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index 9e07c2ff1c..4d4bbbdd7b 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -147,6 +147,7 @@ class ModelMixin(torch.nn.Module): models, `pixel_values` for vision models and `input_values` for speech models). """ config_name = CONFIG_NAME + _automatically_saved_args = ["_diffusers_version", "_class_name", "name_or_path"] def __init__(self): super().__init__() diff --git a/src/diffusers/models/unet_conditional.py b/src/diffusers/models/unet_conditional.py index a034e3f81b..ff24a4fb59 100644 --- a/src/diffusers/models/unet_conditional.py +++ b/src/diffusers/models/unet_conditional.py @@ -63,8 +63,18 @@ class UNetConditionalModel(ModelMixin, ConfigMixin): mid_block_scale_factor=1, center_input_sample=False, resnet_num_groups=30, + **kwargs, ): super().__init__() + # remove automatically added kwargs + for arg in self._automatically_saved_args: + kwargs.pop(arg, None) + + if len(kwargs) > 0: + raise ValueError( + f"The following keyword arguments do not exist for {self.__class__}: {','.join(kwargs.keys())}" + ) + # register all __init__ params to be accessible via `self.config.<...>` # should probably be automated down the road as this is pure boiler plate code self.register_to_config( diff --git a/src/diffusers/models/unet_unconditional.py b/src/diffusers/models/unet_unconditional.py index 610f712565..34ea9a920e 100644 --- a/src/diffusers/models/unet_unconditional.py +++ b/src/diffusers/models/unet_unconditional.py @@ -59,8 +59,18 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): mid_block_scale_factor=1, center_input_sample=False, resnet_num_groups=32, + **kwargs, ): super().__init__() + # remove automatically added kwargs + for arg in self._automatically_saved_args: + kwargs.pop(arg, None) + + if len(kwargs) > 0: + raise ValueError( + f"The following keyword arguments do not exist for {self.__class__}: {','.join(kwargs.keys())}" + ) + # register all __init__ params to be accessible via `self.config.<...>` # should probably be automated down the road as this is pure boiler plate code self.register_to_config(