diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index 05145b97c0..a59a1e7988 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ ConfigMixinuration base class and utilities."""
+import functools
 import inspect
 import json
 import os
@@ -295,3 +296,46 @@ class FrozenDict(OrderedDict):
         if hasattr(self, "__frozen") and self.__frozen:
             raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
         super().__setitem__(name, value)
+
+
+def register_to_config(init):
+    """
+    Decorator to apply on the init of classes inheriting from `ConfigMixin` so that all the arguments are automatically
+    sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that shouldn't be
+    registered in the config, use the `ignore_for_config` class variable.
+
+    Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
+    """
+
+    @functools.wraps(init)
+    def inner_init(self, *args, **kwargs):
+        # Ignore private kwargs in the init.
+        init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
+        init(self, *args, **init_kwargs)
+        if not isinstance(self, ConfigMixin):
+            raise RuntimeError(
+                f"`@register_to_config` was applied to {self.__class__.__name__} init method, but this class does "
+                "not inherit from `ConfigMixin`."
+            )
+
+        ignore = getattr(self, "ignore_for_config", [])
+        # Get positional arguments aligned with kwargs
+        new_kwargs = {}
+        signature = inspect.signature(init)
+        parameters = {
+            name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
+        }
+        for arg, name in zip(args, parameters.keys()):
+            new_kwargs[name] = arg
+
+        # Then add all kwargs
+        new_kwargs.update(
+            {
+                k: init_kwargs.get(k, default)
+                for k, default in parameters.items()
+                if k not in ignore and k not in new_kwargs
+            }
+        )
+        getattr(self, "register_to_config")(**new_kwargs)
+
+    return inner_init
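For orientation, here is a minimal sketch of how the new decorator is meant to be used. The `MyScheduler` class and its arguments are hypothetical, purely for illustration and not part of this patch:

    from diffusers.configuration_utils import ConfigMixin, register_to_config

    class MyScheduler(ConfigMixin):
        config_name = "scheduler_config.json"

        @register_to_config
        def __init__(self, num_train_timesteps=1000, beta_start=0.0001):
            # No explicit self.register_to_config(...) call is needed any more;
            # the decorator collects defaults and overrides and registers them.
            self.num_train_timesteps = num_train_timesteps

    scheduler = MyScheduler(beta_start=0.001)
    assert scheduler.config["num_train_timesteps"] == 1000  # default, registered automatically
    assert scheduler.config["beta_start"] == 0.001          # override, registered as passed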
""" + @register_to_config def __init__( self, image_size=None, @@ -63,40 +64,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin): mid_block_scale_factor=1, center_input_sample=False, resnet_num_groups=30, - **kwargs, ): - super().__init__() - # remove automatically added kwargs - for arg in self._automatically_saved_args: - kwargs.pop(arg, None) - - if len(kwargs) > 0: - raise ValueError( - f"The following keyword arguments do not exist for {self.__class__}: {','.join(kwargs.keys())}" - ) - - # register all __init__ params to be accessible via `self.config.<...>` - # should probably be automated down the road as this is pure boiler plate code - self.register_to_config( - image_size=image_size, - in_channels=in_channels, - block_channels=block_channels, - downsample_padding=downsample_padding, - out_channels=out_channels, - num_res_blocks=num_res_blocks, - down_blocks=down_blocks, - up_blocks=up_blocks, - dropout=dropout, - resnet_eps=resnet_eps, - conv_resample=conv_resample, - num_head_channels=num_head_channels, - flip_sin_to_cos=flip_sin_to_cos, - downscale_freq_shift=downscale_freq_shift, - mid_block_scale_factor=mid_block_scale_factor, - resnet_num_groups=resnet_num_groups, - center_input_sample=center_input_sample, - ) - self.image_size = image_size time_embed_dim = block_channels[0] * 4 diff --git a/src/diffusers/models/unet_unconditional.py b/src/diffusers/models/unet_unconditional.py index 34ea9a920e..dcd673abd8 100644 --- a/src/diffusers/models/unet_unconditional.py +++ b/src/diffusers/models/unet_unconditional.py @@ -3,7 +3,7 @@ from typing import Dict, Union import torch import torch.nn as nn -from ..configuration_utils import ConfigMixin +from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block @@ -33,6 +33,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): increased efficiency. 
""" + @register_to_config def __init__( self, image_size=None, @@ -59,41 +60,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): mid_block_scale_factor=1, center_input_sample=False, resnet_num_groups=32, - **kwargs, ): - super().__init__() - # remove automatically added kwargs - for arg in self._automatically_saved_args: - kwargs.pop(arg, None) - - if len(kwargs) > 0: - raise ValueError( - f"The following keyword arguments do not exist for {self.__class__}: {','.join(kwargs.keys())}" - ) - - # register all __init__ params to be accessible via `self.config.<...>` - # should probably be automated down the road as this is pure boiler plate code - self.register_to_config( - image_size=image_size, - in_channels=in_channels, - block_channels=block_channels, - downsample_padding=downsample_padding, - out_channels=out_channels, - num_res_blocks=num_res_blocks, - down_blocks=down_blocks, - up_blocks=up_blocks, - dropout=dropout, - resnet_eps=resnet_eps, - conv_resample=conv_resample, - num_head_channels=num_head_channels, - flip_sin_to_cos=flip_sin_to_cos, - downscale_freq_shift=downscale_freq_shift, - time_embedding_type=time_embedding_type, - mid_block_scale_factor=mid_block_scale_factor, - resnet_num_groups=resnet_num_groups, - center_input_sample=center_input_sample, - ) - self.image_size = image_size time_embed_dim = block_channels[0] * 4 diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index b3ca0d67a6..58d85800f6 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -2,7 +2,7 @@ import numpy as np import torch import torch.nn as nn -from ..configuration_utils import ConfigMixin +from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin from .attention import AttentionBlock from .resnet import Downsample2D, ResnetBlock2D, Upsample2D @@ -380,6 +380,7 @@ class DiagonalGaussianDistribution(object): class VQModel(ModelMixin, ConfigMixin): + @register_to_config def __init__( self, ch, @@ -399,27 +400,6 @@ class VQModel(ModelMixin, ConfigMixin): resamp_with_conv=True, give_pre_end=False, ): - super().__init__() - - # register all __init__ params with self.register - self.register_to_config( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - n_embed=n_embed, - embed_dim=embed_dim, - remap=remap, - sane_index_shape=sane_index_shape, - ch_mult=ch_mult, - dropout=dropout, - double_z=double_z, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, - ) # pass init params to Encoder self.encoder = Encoder( @@ -478,6 +458,7 @@ class VQModel(ModelMixin, ConfigMixin): class AutoencoderKL(ModelMixin, ConfigMixin): + @register_to_config def __init__( self, ch, @@ -496,26 +477,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin): resamp_with_conv=True, give_pre_end=False, ): - super().__init__() - - # register all __init__ params with self.register - self.register_to_config( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - embed_dim=embed_dim, - remap=remap, - sane_index_shape=sane_index_shape, - ch_mult=ch_mult, - dropout=dropout, - double_z=double_z, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, - ) # pass init params to Encoder self.encoder = Encoder( diff --git a/src/diffusers/schedulers/scheduling_ddim.py 
diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index 1b8d9f0761..e7e4955780 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -21,7 +21,7 @@ from typing import Union
 import numpy as np
 import torch
 
-from ..configuration_utils import ConfigMixin
+from ..configuration_utils import ConfigMixin, register_to_config
 from .scheduling_utils import SchedulerMixin
 
 
@@ -49,6 +49,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
 
 
 class DDIMScheduler(SchedulerMixin, ConfigMixin):
+    @register_to_config
     def __init__(
         self,
         num_train_timesteps=1000,
@@ -60,16 +61,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         clip_sample=True,
         tensor_format="np",
     ):
-        super().__init__()
-        self.register_to_config(
-            num_train_timesteps=num_train_timesteps,
-            beta_start=beta_start,
-            beta_end=beta_end,
-            beta_schedule=beta_schedule,
-            trained_betas=trained_betas,
-            timestep_values=timestep_values,
-            clip_sample=clip_sample,
-        )
 
         if beta_schedule == "linear":
             self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index 25eae068a9..979d0f4f34 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -20,7 +20,7 @@ from typing import Union
 import numpy as np
 import torch
 
-from ..configuration_utils import ConfigMixin
+from ..configuration_utils import ConfigMixin, register_to_config
 from .scheduling_utils import SchedulerMixin
 
 
@@ -48,6 +48,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
 
 
 class DDPMScheduler(SchedulerMixin, ConfigMixin):
+    @register_to_config
     def __init__(
         self,
         num_train_timesteps=1000,
@@ -60,17 +61,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         clip_sample=True,
         tensor_format="np",
     ):
-        super().__init__()
-        self.register_to_config(
-            num_train_timesteps=num_train_timesteps,
-            beta_start=beta_start,
-            beta_end=beta_end,
-            beta_schedule=beta_schedule,
-            trained_betas=trained_betas,
-            timestep_values=timestep_values,
-            variance_type=variance_type,
-            clip_sample=clip_sample,
-        )
 
         if trained_betas is not None:
             self.betas = np.asarray(trained_betas)
diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py
index 3b889d0ac2..216c4a715f 100644
--- a/src/diffusers/schedulers/scheduling_pndm.py
+++ b/src/diffusers/schedulers/scheduling_pndm.py
@@ -20,7 +20,7 @@ from typing import Union
 import numpy as np
 import torch
 
-from ..configuration_utils import ConfigMixin
+from ..configuration_utils import ConfigMixin, register_to_config
 from .scheduling_utils import SchedulerMixin
 
 
@@ -48,6 +48,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
 
 
 class PNDMScheduler(SchedulerMixin, ConfigMixin):
+    @register_to_config
     def __init__(
         self,
         num_train_timesteps=1000,
@@ -56,13 +57,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         beta_schedule="linear",
         tensor_format="np",
     ):
-        super().__init__()
-        self.register_to_config(
-            num_train_timesteps=num_train_timesteps,
-            beta_start=beta_start,
-            beta_end=beta_end,
-            beta_schedule=beta_schedule,
-        )
 
         if beta_schedule == "linear":
             self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
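The scheduler conversions above are purely mechanical: the same arguments end up in the config, just without the boilerplate, and since neither `SchedulerMixin` nor `ConfigMixin` defines an `__init__` of its own, dropping the explicit `super().__init__()` call is harmless for the schedulers. A quick sketch using the DDIM defaults shown in the hunk (a sketch, not a test from the patch):

    from diffusers import DDIMScheduler

    scheduler = DDIMScheduler(beta_end=0.012)
    assert scheduler.config["num_train_timesteps"] == 1000  # default value, registered by the decorator
    assert scheduler.config["beta_end"] == 0.012            # explicit override, registered as passed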
diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py
index bf9a22d93e..2f21faa2bf 100644
--- a/src/diffusers/schedulers/scheduling_sde_ve.py
+++ b/src/diffusers/schedulers/scheduling_sde_ve.py
@@ -21,7 +21,7 @@ from typing import Union
 import numpy as np
 import torch
 
-from ..configuration_utils import ConfigMixin
+from ..configuration_utils import ConfigMixin, register_to_config
 from .scheduling_utils import SchedulerMixin
 
 
@@ -37,6 +37,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         "np" or "pt" for the expected format of samples passed to the Scheduler.
     """
 
+    @register_to_config
     def __init__(
         self,
         num_train_timesteps=2000,
@@ -47,15 +48,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         correct_steps=1,
         tensor_format="pt",
     ):
-        super().__init__()
-        self.register_to_config(
-            num_train_timesteps=num_train_timesteps,
-            snr=snr,
-            sigma_min=sigma_min,
-            sigma_max=sigma_max,
-            sampling_eps=sampling_eps,
-            correct_steps=correct_steps,
-        )
         # self.sigmas = None
         # self.discrete_sigmas = None
         #
diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py
index 08f1c2af0c..d24aeaea05 100644
--- a/src/diffusers/schedulers/scheduling_sde_vp.py
+++ b/src/diffusers/schedulers/scheduling_sde_vp.py
@@ -19,19 +19,13 @@
 import numpy as np
 import torch
 
-from ..configuration_utils import ConfigMixin
+from ..configuration_utils import ConfigMixin, register_to_config
 from .scheduling_utils import SchedulerMixin
 
 
 class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
+    @register_to_config
     def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
-        super().__init__()
-        self.register_to_config(
-            num_train_timesteps=num_train_timesteps,
-            beta_min=beta_min,
-            beta_max=beta_max,
-            sampling_eps=sampling_eps,
-        )
 
         self.sigmas = None
         self.discrete_sigmas = None
diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py
index 9a66d2405c..7f573cc24b 100644
--- a/src/diffusers/schedulers/scheduling_utils.py
+++ b/src/diffusers/schedulers/scheduling_utils.py
@@ -23,6 +23,7 @@ SCHEDULER_CONFIG_NAME = "scheduler_config.json"
 
 class SchedulerMixin:
     config_name = SCHEDULER_CONFIG_NAME
+    ignore_for_config = ["tensor_format"]
 
     def set_format(self, tensor_format="pt"):
         self.tensor_format = tensor_format
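The new `ignore_for_config` class variable is what keeps `tensor_format` out of the registered config even though every scheduler init still accepts it: the decorator skips any parameter listed there. Roughly (a sketch, using the DDPM defaults shown above):

    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler(tensor_format="pt")
    assert "tensor_format" not in scheduler.config     # excluded via SchedulerMixin.ignore_for_config
    assert scheduler.config["clip_sample"] is True     # ordinary arguments are still registered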
== "for diffusion" + assert config["e"] == [1, 3] + + # init ignore private arguments + obj = SampleObject(_name_or_path="lalala") + config = obj.config + assert config["a"] == 2 + assert config["b"] == 5 + assert config["c"] == (2, 5) + assert config["d"] == "for diffusion" + assert config["e"] == [1, 3] + + # can override default + obj = SampleObject(c=6) + config = obj.config + assert config["a"] == 2 + assert config["b"] == 5 + assert config["c"] == 6 + assert config["d"] == "for diffusion" + assert config["e"] == [1, 3] + + # can use positional arguments. + obj = SampleObject(1, c=6) + config = obj.config + assert config["a"] == 1 + assert config["b"] == 5 + assert config["c"] == 6 + assert config["d"] == "for diffusion" + assert config["e"] == [1, 3] + def test_save_load(self): - class SampleObject(ConfigMixin): - config_name = "config.json" - - def __init__( - self, - a=2, - b=5, - c=(2, 5), - d="for diffusion", - e=[1, 3], - ): - self.register_to_config(a=a, b=b, c=c, d=d, e=e) - obj = SampleObject() config = obj.config