mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add a decorator for register_to_config (#108)

* Add a decorator for register_to_config

* All models and test
Sylvain Gugger
2022-07-20 15:42:50 +02:00
committed by GitHub
parent 7e11392dfd
commit ad9d252596
11 changed files with 115 additions and 174 deletions

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" ConfigMixinuration base class and utilities."""
import functools
import inspect
import json
import os
@@ -295,3 +296,46 @@ class FrozenDict(OrderedDict):
if hasattr(self, "__frozen") and self.__frozen:
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
super().__setitem__(name, value)

def register_to_config(init):
    """
    Decorator to apply on the init of classes inheriting from `ConfigMixin` so that all the arguments are
    automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
    shouldn't be registered in the config, use the `ignore_for_config` class variable.

    Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the
    init!
    """

    @functools.wraps(init)
    def inner_init(self, *args, **kwargs):
        # Ignore private kwargs in the init.
        init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
        init(self, *args, **init_kwargs)
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_to_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        ignore = getattr(self, "ignore_for_config", [])
        # Get positional arguments aligned with kwargs
        new_kwargs = {}
        signature = inspect.signature(init)
        parameters = {
            name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
        }
        for arg, name in zip(args, parameters.keys()):
            new_kwargs[name] = arg

        # Then add all kwargs
        new_kwargs.update(
            {
                k: init_kwargs.get(k, default)
                for k, default in parameters.items()
                if k not in ignore and k not in new_kwargs
            }
        )
        getattr(self, "register_to_config")(**new_kwargs)

    return inner_init
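
To illustrate the change this enables in the rest of the diff, here is a minimal before/after sketch. It is not part of the commit: the class names OldStyle and NewStyle, their parameters, and the config_name values are made up for illustration.

# Hypothetical before/after sketch -- not from the repository.
from diffusers.configuration_utils import ConfigMixin, register_to_config


class OldStyle(ConfigMixin):
    config_name = "old_style_config.json"

    def __init__(self, hidden_size=64, dropout=0.0):
        # Before this commit: every __init__ repeated this boilerplate by hand.
        self.register_to_config(hidden_size=hidden_size, dropout=dropout)


class NewStyle(ConfigMixin):
    config_name = "new_style_config.json"

    @register_to_config
    def __init__(self, hidden_size=64, dropout=0.0):
        # After this commit: the decorator registers every non-private,
        # non-ignored __init__ argument (including defaults) automatically.
        pass


old, new = OldStyle(dropout=0.1), NewStyle(dropout=0.1)
assert old.config["dropout"] == new.config["dropout"] == 0.1
assert new.config["hidden_size"] == 64  # defaults are registered too
NewStyle(_name_or_path="x")  # private kwargs are stripped before they reach __init__

Because the decorator reads the wrapped __init__ signature with inspect.signature, defaults and positional arguments end up in the config exactly as explicitly passed keyword arguments do.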

View File

@@ -3,7 +3,7 @@ from typing import Dict, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from .embeddings import TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
@@ -33,6 +33,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
    increased efficiency.
    """

    @register_to_config
    def __init__(
        self,
        image_size=None,
@@ -63,40 +64,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
        mid_block_scale_factor=1,
        center_input_sample=False,
        resnet_num_groups=30,
        **kwargs,
    ):
        super().__init__()
        # remove automatically added kwargs
        for arg in self._automatically_saved_args:
            kwargs.pop(arg, None)
        if len(kwargs) > 0:
            raise ValueError(
                f"The following keyword arguments do not exist for {self.__class__}: {','.join(kwargs.keys())}"
            )
        # register all __init__ params to be accessible via `self.config.<...>`
        # should probably be automated down the road as this is pure boiler plate code
        self.register_to_config(
            image_size=image_size,
            in_channels=in_channels,
            block_channels=block_channels,
            downsample_padding=downsample_padding,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            down_blocks=down_blocks,
            up_blocks=up_blocks,
            dropout=dropout,
            resnet_eps=resnet_eps,
            conv_resample=conv_resample,
            num_head_channels=num_head_channels,
            flip_sin_to_cos=flip_sin_to_cos,
            downscale_freq_shift=downscale_freq_shift,
            mid_block_scale_factor=mid_block_scale_factor,
            resnet_num_groups=resnet_num_groups,
            center_input_sample=center_input_sample,
        )
        self.image_size = image_size
        time_embed_dim = block_channels[0] * 4

View File

@@ -3,7 +3,7 @@ from typing import Dict, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
@@ -33,6 +33,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
    increased efficiency.
    """

    @register_to_config
    def __init__(
        self,
        image_size=None,
@@ -59,41 +60,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
        mid_block_scale_factor=1,
        center_input_sample=False,
        resnet_num_groups=32,
        **kwargs,
    ):
        super().__init__()
        # remove automatically added kwargs
        for arg in self._automatically_saved_args:
            kwargs.pop(arg, None)
        if len(kwargs) > 0:
            raise ValueError(
                f"The following keyword arguments do not exist for {self.__class__}: {','.join(kwargs.keys())}"
            )
        # register all __init__ params to be accessible via `self.config.<...>`
        # should probably be automated down the road as this is pure boiler plate code
        self.register_to_config(
            image_size=image_size,
            in_channels=in_channels,
            block_channels=block_channels,
            downsample_padding=downsample_padding,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            down_blocks=down_blocks,
            up_blocks=up_blocks,
            dropout=dropout,
            resnet_eps=resnet_eps,
            conv_resample=conv_resample,
            num_head_channels=num_head_channels,
            flip_sin_to_cos=flip_sin_to_cos,
            downscale_freq_shift=downscale_freq_shift,
            time_embedding_type=time_embedding_type,
            mid_block_scale_factor=mid_block_scale_factor,
            resnet_num_groups=resnet_num_groups,
            center_input_sample=center_input_sample,
        )
        self.image_size = image_size
        time_embed_dim = block_channels[0] * 4

View File

@@ -2,7 +2,7 @@ import numpy as np
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
@@ -380,6 +380,7 @@ class DiagonalGaussianDistribution(object):
class VQModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        ch,
@@ -399,27 +400,6 @@ class VQModel(ModelMixin, ConfigMixin):
        resamp_with_conv=True,
        give_pre_end=False,
    ):
        super().__init__()
        # register all __init__ params with self.register
        self.register_to_config(
            ch=ch,
            out_ch=out_ch,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            n_embed=n_embed,
            embed_dim=embed_dim,
            remap=remap,
            sane_index_shape=sane_index_shape,
            ch_mult=ch_mult,
            dropout=dropout,
            double_z=double_z,
            resamp_with_conv=resamp_with_conv,
            give_pre_end=give_pre_end,
        )
        # pass init params to Encoder
        self.encoder = Encoder(
@@ -478,6 +458,7 @@ class VQModel(ModelMixin, ConfigMixin):
class AutoencoderKL(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        ch,
@@ -496,26 +477,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
        resamp_with_conv=True,
        give_pre_end=False,
    ):
        super().__init__()
        # register all __init__ params with self.register
        self.register_to_config(
            ch=ch,
            out_ch=out_ch,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            embed_dim=embed_dim,
            remap=remap,
            sane_index_shape=sane_index_shape,
            ch_mult=ch_mult,
            dropout=dropout,
            double_z=double_z,
            resamp_with_conv=resamp_with_conv,
            give_pre_end=give_pre_end,
        )
        # pass init params to Encoder
        self.encoder = Encoder(

View File

@@ -21,7 +21,7 @@ from typing import Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
@@ -49,6 +49,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class DDIMScheduler(SchedulerMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        num_train_timesteps=1000,
@@ -60,16 +61,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        clip_sample=True,
        tensor_format="np",
    ):
        super().__init__()
        self.register_to_config(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            trained_betas=trained_betas,
            timestep_values=timestep_values,
            clip_sample=clip_sample,
        )
        if beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)

View File

@@ -20,7 +20,7 @@ from typing import Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
@@ -48,6 +48,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class DDPMScheduler(SchedulerMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        num_train_timesteps=1000,
@@ -60,17 +61,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
        clip_sample=True,
        tensor_format="np",
    ):
        super().__init__()
        self.register_to_config(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            trained_betas=trained_betas,
            timestep_values=timestep_values,
            variance_type=variance_type,
            clip_sample=clip_sample,
        )
        if trained_betas is not None:
            self.betas = np.asarray(trained_betas)

View File

@@ -20,7 +20,7 @@ from typing import Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
@@ -48,6 +48,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class PNDMScheduler(SchedulerMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        num_train_timesteps=1000,
@@ -56,13 +57,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
        beta_schedule="linear",
        tensor_format="np",
    ):
        super().__init__()
        self.register_to_config(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
        )
        if beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)

View File

@@ -21,7 +21,7 @@ from typing import Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
@@ -37,6 +37,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
"np" or "pt" for the expected format of samples passed to the Scheduler.
"""
@register_to_config
def __init__(
self,
num_train_timesteps=2000,
@@ -47,15 +48,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
correct_steps=1,
tensor_format="pt",
):
super().__init__()
self.register_to_config(
num_train_timesteps=num_train_timesteps,
snr=snr,
sigma_min=sigma_min,
sigma_max=sigma_max,
sampling_eps=sampling_eps,
correct_steps=correct_steps,
)
# self.sigmas = None
# self.discrete_sigmas = None
#

View File

@@ -19,19 +19,13 @@
import numpy as np
import torch
from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
    @register_to_config
    def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
        super().__init__()
        self.register_to_config(
            num_train_timesteps=num_train_timesteps,
            beta_min=beta_min,
            beta_max=beta_max,
            sampling_eps=sampling_eps,
        )
        self.sigmas = None
        self.discrete_sigmas = None

View File

@@ -23,6 +23,7 @@ SCHEDULER_CONFIG_NAME = "scheduler_config.json"
class SchedulerMixin:
    config_name = SCHEDULER_CONFIG_NAME
    ignore_for_config = ["tensor_format"]

    def set_format(self, tensor_format="pt"):
        self.tensor_format = tensor_format
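
Because SchedulerMixin now declares ignore_for_config, every decorated scheduler keeps tensor_format out of its registered config while the argument is still forwarded to __init__. A small illustrative check (not a test from this diff; it assumes the usual top-level DDIMScheduler import):

# Illustrative sketch of ignore_for_config -- not part of this diff.
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=500, tensor_format="pt")
assert scheduler.config["num_train_timesteps"] == 500  # registered by @register_to_config
assert "tensor_format" not in scheduler.config          # excluded via SchedulerMixin.ignore_for_config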

View File

@@ -18,6 +18,7 @@ import inspect
import math
import tempfile
import unittest
from atexit import register
import numpy as np
import torch
@@ -38,7 +39,7 @@ from diffusers import (
    UNetUnconditionalModel,
    VQModel,
)
from diffusers.configuration_utils import ConfigMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.testing_utils import floats_tensor, slow, torch_device
from diffusers.training_utils import EMAModel
@@ -47,25 +48,63 @@ from diffusers.training_utils import EMAModel
torch.backends.cuda.matmul.allow_tf32 = False
class SampleObject(ConfigMixin):
    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        a=2,
        b=5,
        c=(2, 5),
        d="for diffusion",
        e=[1, 3],
    ):
        pass


class ConfigTester(unittest.TestCase):
    def test_load_not_from_mixin(self):
        with self.assertRaises(ValueError):
            ConfigMixin.from_config("dummy_path")

    def test_register_to_config(self):
        obj = SampleObject()
        config = obj.config
        assert config["a"] == 2
        assert config["b"] == 5
        assert config["c"] == (2, 5)
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

        # init ignores private arguments
        obj = SampleObject(_name_or_path="lalala")
        config = obj.config
        assert config["a"] == 2
        assert config["b"] == 5
        assert config["c"] == (2, 5)
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

        # can override default
        obj = SampleObject(c=6)
        config = obj.config
        assert config["a"] == 2
        assert config["b"] == 5
        assert config["c"] == 6
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

        # can use positional arguments.
        obj = SampleObject(1, c=6)
        config = obj.config
        assert config["a"] == 1
        assert config["b"] == 5
        assert config["c"] == 6
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

    def test_save_load(self):
        class SampleObject(ConfigMixin):
            config_name = "config.json"

            def __init__(
                self,
                a=2,
                b=5,
                c=(2, 5),
                d="for diffusion",
                e=[1, 3],
            ):
                self.register_to_config(a=a, b=b, c=c, d=d, e=e)

        obj = SampleObject()
        config = obj.config