1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

correct dict

This commit is contained in:
Patrick von Platen
2022-06-17 11:55:02 +02:00
parent e660a05fed
commit fe7d136324
4 changed files with 52 additions and 11 deletions

View File

@@ -14,13 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" ConfigMixinuration base class and utilities."""
import copy
import inspect
import json
import os
import re
from collections import OrderedDict
from typing import Any, Dict, Tuple, Union
from huggingface_hub import hf_hub_download
@@ -63,10 +61,14 @@ class ConfigMixin:
logger.error(f"Can't set {key} with value {value} for {self}")
raise err
if not hasattr(self, "_dict_to_save"):
self._dict_to_save = {}
if not hasattr(self, "_internal_dict"):
internal_dict = kwargs
else:
previous_dict = dict(self._internal_dict)
internal_dict = {**self._internal_dict, **kwargs}
logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
self._dict_to_save.update(kwargs)
self._internal_dict = FrozenDict(internal_dict)
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
@@ -230,8 +232,7 @@ class ConfigMixin:
@property
def config(self) -> Dict[str, Any]:
output = copy.deepcopy(self._dict_to_save)
return output
return self._internal_dict
def to_json_string(self) -> str:
"""
@@ -240,7 +241,7 @@ class ConfigMixin:
Returns:
`str`: String containing all the attributes that make up this configuration instance in JSON format.
"""
config_dict = self._dict_to_save
config_dict = self._internal_dict
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
@@ -253,3 +254,39 @@ class ConfigMixin:
"""
with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string())
class FrozenDict(OrderedDict):
def __init__(self, *args, **kwargs):
# remove `None`
args = (a for a in args if a is not None)
kwargs = {k: v for k, v in kwargs if v is not None}
super().__init__(*args, **kwargs)
for key, value in self.items():
setattr(self, key, value)
self.__frozen = True
def __delitem__(self, *args, **kwargs):
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
def setdefault(self, *args, **kwargs):
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
def pop(self, *args, **kwargs):
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
def update(self, *args, **kwargs):
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
def __setattr__(self, name, value):
if hasattr(self, "__frozen") and self.__frozen:
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
super().__setattr__(name, value)
def __setitem__(self, name, value):
if hasattr(self, "__frozen") and self.__frozen:
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
super().__setitem__(name, value)

View File

@@ -338,7 +338,7 @@ class ModelMixin(torch.nn.Module):
revision=revision,
**kwargs,
)
model.register(name_or_path=pretrained_model_name_or_path)
model.register_to_config(name_or_path=pretrained_model_name_or_path)
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# Load model
pretrained_model_name_or_path = str(pretrained_model_name_or_path)

View File

@@ -88,7 +88,7 @@ class DiffusionPipeline(ConfigMixin):
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
self.save_config(save_directory)
model_index_dict = self.config
model_index_dict = dict(self.config)
model_index_dict.pop("_class_name")
model_index_dict.pop("_diffusers_version")
model_index_dict.pop("_module")

View File

@@ -73,6 +73,10 @@ class ConfigTester(unittest.TestCase):
new_obj = SampleObject.from_config(tmpdirname)
new_config = new_obj.config
# unfreeze configs
config = dict(config)
new_config = dict(new_config)
assert config.pop("c") == (2, 5) # instantiated as tuple
assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json
assert config == new_config