From fe7d1363241dfc56b986bb1287e075a35fee743c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 17 Jun 2022 11:55:02 +0200 Subject: [PATCH] correct dict --- src/diffusers/configuration_utils.py | 55 +++++++++++++++++++++++----- src/diffusers/modeling_utils.py | 2 +- src/diffusers/pipeline_utils.py | 2 +- tests/test_modeling_utils.py | 4 ++ 4 files changed, 52 insertions(+), 11 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 1b7b23566d..c241fd5a17 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -14,13 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """ ConfigMixinuration base class and utilities.""" - - -import copy import inspect import json import os import re +from collections import OrderedDict from typing import Any, Dict, Tuple, Union from huggingface_hub import hf_hub_download @@ -63,10 +61,14 @@ class ConfigMixin: logger.error(f"Can't set {key} with value {value} for {self}") raise err - if not hasattr(self, "_dict_to_save"): - self._dict_to_save = {} + if not hasattr(self, "_internal_dict"): + internal_dict = kwargs + else: + previous_dict = dict(self._internal_dict) + internal_dict = {**self._internal_dict, **kwargs} + logger.debug(f"Updating config from {previous_dict} to {internal_dict}") - self._dict_to_save.update(kwargs) + self._internal_dict = FrozenDict(internal_dict) def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): """ @@ -230,8 +232,7 @@ class ConfigMixin: @property def config(self) -> Dict[str, Any]: - output = copy.deepcopy(self._dict_to_save) - return output + return self._internal_dict def to_json_string(self) -> str: """ @@ -240,7 +241,7 @@ class ConfigMixin: Returns: `str`: String containing all the attributes that make up this configuration instance in JSON format. """ - config_dict = self._dict_to_save + config_dict = self._internal_dict return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" def to_json_file(self, json_file_path: Union[str, os.PathLike]): @@ -253,3 +254,39 @@ class ConfigMixin: """ with open(json_file_path, "w", encoding="utf-8") as writer: writer.write(self.to_json_string()) + + +class FrozenDict(OrderedDict): + def __init__(self, *args, **kwargs): + # remove `None` + args = (a for a in args if a is not None) + kwargs = {k: v for k, v in kwargs if v is not None} + + super().__init__(*args, **kwargs) + + for key, value in self.items(): + setattr(self, key, value) + + self.__frozen = True + + def __delitem__(self, *args, **kwargs): + raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") + + def setdefault(self, *args, **kwargs): + raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") + + def pop(self, *args, **kwargs): + raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") + + def update(self, *args, **kwargs): + raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") + + def __setattr__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") + super().__setattr__(name, value) + + def __setitem__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") + super().__setitem__(name, value) diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index 13a4c2efdc..2dd1b9980a 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -338,7 +338,7 @@ class ModelMixin(torch.nn.Module): revision=revision, **kwargs, ) - model.register(name_or_path=pretrained_model_name_or_path) + model.register_to_config(name_or_path=pretrained_model_name_or_path) # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the # Load model pretrained_model_name_or_path = str(pretrained_model_name_or_path) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 6b49f22c48..f19e07de28 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -88,7 +88,7 @@ class DiffusionPipeline(ConfigMixin): def save_pretrained(self, save_directory: Union[str, os.PathLike]): self.save_config(save_directory) - model_index_dict = self.config + model_index_dict = dict(self.config) model_index_dict.pop("_class_name") model_index_dict.pop("_diffusers_version") model_index_dict.pop("_module") diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 1fa40cbd4f..d678831b20 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -73,6 +73,10 @@ class ConfigTester(unittest.TestCase): new_obj = SampleObject.from_config(tmpdirname) new_config = new_obj.config + # unfreeze configs + config = dict(config) + new_config = dict(new_config) + assert config.pop("c") == (2, 5) # instantiated as tuple assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json assert config == new_config