1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

keep _use_default_values as a list type (#4040)

Signed-off-by: Raphael <oOraph@users.noreply.github.com>
Co-authored-by: Raphael <oOraph@users.noreply.github.com>
This commit is contained in:
oOraph
2023-07-11 18:21:08 +02:00
committed by Patrick von Platen
parent a2fa787121
commit c4402daff1
2 changed files with 3 additions and 3 deletions

View File

@@ -607,7 +607,7 @@ def register_to_config(init):
# Take note of the parameters that were not present in the loaded config
if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
new_kwargs["_use_default_values"] = set(new_kwargs.keys()) - set(init_kwargs)
new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
new_kwargs = {**config_init_kwargs, **new_kwargs}
getattr(self, "register_to_config")(**new_kwargs)
@@ -655,7 +655,7 @@ def flax_register_to_config(cls):
# Take note of the parameters that were not present in the loaded config
if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
new_kwargs["_use_default_values"] = set(new_kwargs.keys()) - set(init_kwargs)
new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
getattr(self, "register_to_config")(**new_kwargs)
original_init(self, *args, **kwargs)

View File

@@ -264,7 +264,7 @@ class ConfigTester(unittest.TestCase):
config_dict = {k: v for k, v in config.config.items() if not k.startswith("_")}
# make sure that default config has all keys in `_use_default_values`
assert set(config_dict.keys()) == config.config._use_default_values
assert set(config_dict.keys()) == set(config.config._use_default_values)
with tempfile.TemporaryDirectory() as tmpdirname:
config.save_config(tmpdirname)