Mirror of https://github.com/huggingface/diffusers.git
Add a decorator for register_to_config (#108)
* Add a decorator for register_to_config
* All models and test
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" ConfigMixinuration base class and utilities."""
import functools
import inspect
import json
import os
@@ -295,3 +296,46 @@ class FrozenDict(OrderedDict):
        if hasattr(self, "__frozen") and self.__frozen:
            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
        super().__setitem__(name, value)


def register_to_config(init):
    """
    Decorator to apply on the init of classes inheriting from `ConfigMixin` so that all the arguments are automatically
    sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that shouldn't be
    registered in the config, use the `ignore_for_config` class variable.

    Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
    """

    @functools.wraps(init)
    def inner_init(self, *args, **kwargs):
        # Ignore private kwargs in the init.
        init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
        init(self, *args, **init_kwargs)
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_to_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        ignore = getattr(self, "ignore_for_config", [])
        # Get positional arguments aligned with kwargs
        new_kwargs = {}
        signature = inspect.signature(init)
        parameters = {
            name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
        }
        for arg, name in zip(args, parameters.keys()):
            new_kwargs[name] = arg

        # Then add all kwargs
        new_kwargs.update(
            {
                k: init_kwargs.get(k, default)
                for k, default in parameters.items()
                if k not in ignore and k not in new_kwargs
            }
        )
        getattr(self, "register_to_config")(**new_kwargs)

    return inner_init

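Note (not part of the diff): a minimal usage sketch of the new decorator, mirroring the test added at the end of this commit. The class and argument names are illustrative; the behavior assumed is exactly what `inner_init` above implements — positional arguments, keyword overrides, and untouched defaults are all forwarded to `register_to_config`, while underscore-prefixed kwargs are dropped before reaching `__init__`.

from diffusers.configuration_utils import ConfigMixin, register_to_config

class SampleObject(ConfigMixin):
    config_name = "config.json"  # file name ConfigMixin uses when saving the config

    @register_to_config
    def __init__(self, a=2, b=5, c=(2, 5), d="for diffusion"):
        pass  # no manual self.register_to_config(...) call needed anymore

# positional args, keyword overrides, and untouched defaults all end up in the config
obj = SampleObject(1, c=6)
assert obj.config["a"] == 1
assert obj.config["b"] == 5
assert obj.config["c"] == 6

# private arguments (leading underscore) are silently dropped before reaching __init__
obj = SampleObject(_name_or_path="lalala")
assert obj.config["a"] == 2
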
@@ -3,7 +3,7 @@ from typing import Dict, Union
import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from .embeddings import TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
@@ -33,6 +33,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
        increased efficiency.
    """

    @register_to_config
    def __init__(
        self,
        image_size=None,
@@ -63,40 +64,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
        mid_block_scale_factor=1,
        center_input_sample=False,
        resnet_num_groups=30,
        **kwargs,
    ):
        super().__init__()
        # remove automatically added kwargs
        for arg in self._automatically_saved_args:
            kwargs.pop(arg, None)

        if len(kwargs) > 0:
            raise ValueError(
                f"The following keyword arguments do not exist for {self.__class__}: {','.join(kwargs.keys())}"
            )

        # register all __init__ params to be accessible via `self.config.<...>`
        # should probably be automated down the road as this is pure boiler plate code
        self.register_to_config(
            image_size=image_size,
            in_channels=in_channels,
            block_channels=block_channels,
            downsample_padding=downsample_padding,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            down_blocks=down_blocks,
            up_blocks=up_blocks,
            dropout=dropout,
            resnet_eps=resnet_eps,
            conv_resample=conv_resample,
            num_head_channels=num_head_channels,
            flip_sin_to_cos=flip_sin_to_cos,
            downscale_freq_shift=downscale_freq_shift,
            mid_block_scale_factor=mid_block_scale_factor,
            resnet_num_groups=resnet_num_groups,
            center_input_sample=center_input_sample,
        )

        self.image_size = image_size
        time_embed_dim = block_channels[0] * 4

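Note (not part of the diff): the hand-written `self.register_to_config(...)` block deleted above is exactly the bookkeeping the decorator now performs, which is why the init arguments stay reachable via `self.config.<...>` as the removed comment noted. A before/after sketch under that assumption, with illustrative class names:

from diffusers.configuration_utils import ConfigMixin, register_to_config

# before this commit: every argument repeated by hand
class OldStyleModel(ConfigMixin):
    config_name = "config.json"

    def __init__(self, image_size=None, in_channels=3):
        self.register_to_config(image_size=image_size, in_channels=in_channels)

# after this commit: the decorator does the bookkeeping
class NewStyleModel(ConfigMixin):
    config_name = "config.json"

    @register_to_config
    def __init__(self, image_size=None, in_channels=3):
        pass

assert OldStyleModel(image_size=64).config["image_size"] == NewStyleModel(image_size=64).config["image_size"]
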
@@ -3,7 +3,7 @@ from typing import Dict, Union
import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
@@ -33,6 +33,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
        increased efficiency.
    """

    @register_to_config
    def __init__(
        self,
        image_size=None,
@@ -59,41 +60,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
        mid_block_scale_factor=1,
        center_input_sample=False,
        resnet_num_groups=32,
        **kwargs,
    ):
        super().__init__()
        # remove automatically added kwargs
        for arg in self._automatically_saved_args:
            kwargs.pop(arg, None)

        if len(kwargs) > 0:
            raise ValueError(
                f"The following keyword arguments do not exist for {self.__class__}: {','.join(kwargs.keys())}"
            )

        # register all __init__ params to be accessible via `self.config.<...>`
        # should probably be automated down the road as this is pure boiler plate code
        self.register_to_config(
            image_size=image_size,
            in_channels=in_channels,
            block_channels=block_channels,
            downsample_padding=downsample_padding,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            down_blocks=down_blocks,
            up_blocks=up_blocks,
            dropout=dropout,
            resnet_eps=resnet_eps,
            conv_resample=conv_resample,
            num_head_channels=num_head_channels,
            flip_sin_to_cos=flip_sin_to_cos,
            downscale_freq_shift=downscale_freq_shift,
            time_embedding_type=time_embedding_type,
            mid_block_scale_factor=mid_block_scale_factor,
            resnet_num_groups=resnet_num_groups,
            center_input_sample=center_input_sample,
        )

        self.image_size = image_size
        time_embed_dim = block_channels[0] * 4

@@ -2,7 +2,7 @@ import numpy as np
import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
@@ -380,6 +380,7 @@ class DiagonalGaussianDistribution(object):


class VQModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        ch,
@@ -399,27 +400,6 @@ class VQModel(ModelMixin, ConfigMixin):
        resamp_with_conv=True,
        give_pre_end=False,
    ):
        super().__init__()

        # register all __init__ params with self.register
        self.register_to_config(
            ch=ch,
            out_ch=out_ch,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            n_embed=n_embed,
            embed_dim=embed_dim,
            remap=remap,
            sane_index_shape=sane_index_shape,
            ch_mult=ch_mult,
            dropout=dropout,
            double_z=double_z,
            resamp_with_conv=resamp_with_conv,
            give_pre_end=give_pre_end,
        )

        # pass init params to Encoder
        self.encoder = Encoder(
@@ -478,6 +458,7 @@ class VQModel(ModelMixin, ConfigMixin):


class AutoencoderKL(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        ch,
@@ -496,26 +477,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
        resamp_with_conv=True,
        give_pre_end=False,
    ):
        super().__init__()

        # register all __init__ params with self.register
        self.register_to_config(
            ch=ch,
            out_ch=out_ch,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            embed_dim=embed_dim,
            remap=remap,
            sane_index_shape=sane_index_shape,
            ch_mult=ch_mult,
            dropout=dropout,
            double_z=double_z,
            resamp_with_conv=resamp_with_conv,
            give_pre_end=give_pre_end,
        )

        # pass init params to Encoder
        self.encoder = Encoder(

@@ -21,7 +21,7 @@ from typing import Union
import numpy as np
import torch

from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin


@@ -49,6 +49,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):


class DDIMScheduler(SchedulerMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        num_train_timesteps=1000,
@@ -60,16 +61,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        clip_sample=True,
        tensor_format="np",
    ):
        super().__init__()
        self.register_to_config(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            trained_betas=trained_betas,
            timestep_values=timestep_values,
            clip_sample=clip_sample,
        )

        if beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)

@@ -20,7 +20,7 @@ from typing import Union
import numpy as np
import torch

from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin


@@ -48,6 +48,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):


class DDPMScheduler(SchedulerMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        num_train_timesteps=1000,
@@ -60,17 +61,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
        clip_sample=True,
        tensor_format="np",
    ):
        super().__init__()
        self.register_to_config(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            trained_betas=trained_betas,
            timestep_values=timestep_values,
            variance_type=variance_type,
            clip_sample=clip_sample,
        )

        if trained_betas is not None:
            self.betas = np.asarray(trained_betas)

@@ -20,7 +20,7 @@ from typing import Union
import numpy as np
import torch

from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin


@@ -48,6 +48,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):


class PNDMScheduler(SchedulerMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        num_train_timesteps=1000,
@@ -56,13 +57,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
        beta_schedule="linear",
        tensor_format="np",
    ):
        super().__init__()
        self.register_to_config(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
        )

        if beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)

@@ -21,7 +21,7 @@ from typing import Union
import numpy as np
import torch

from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin


@@ -37,6 +37,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
        "np" or "pt" for the expected format of samples passed to the Scheduler.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps=2000,
@@ -47,15 +48,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
        correct_steps=1,
        tensor_format="pt",
    ):
        super().__init__()
        self.register_to_config(
            num_train_timesteps=num_train_timesteps,
            snr=snr,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            sampling_eps=sampling_eps,
            correct_steps=correct_steps,
        )
        # self.sigmas = None
        # self.discrete_sigmas = None
        #

@@ -19,19 +19,13 @@
import numpy as np
import torch

from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin


class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
    @register_to_config
    def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
        super().__init__()
        self.register_to_config(
            num_train_timesteps=num_train_timesteps,
            beta_min=beta_min,
            beta_max=beta_max,
            sampling_eps=sampling_eps,
        )

        self.sigmas = None
        self.discrete_sigmas = None

@@ -23,6 +23,7 @@ SCHEDULER_CONFIG_NAME = "scheduler_config.json"
class SchedulerMixin:

    config_name = SCHEDULER_CONFIG_NAME
    ignore_for_config = ["tensor_format"]

    def set_format(self, tensor_format="pt"):
        self.tensor_format = tensor_format

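Note (not part of the diff): `ignore_for_config` is the class variable referenced in the decorator's docstring; listing `tensor_format` here keeps it out of every scheduler's registered config even though the decorated `__init__` still receives it. A small sketch under that assumption, with an illustrative class name (real schedulers also inherit from `SchedulerMixin`):

from diffusers.configuration_utils import ConfigMixin, register_to_config

class ToyScheduler(ConfigMixin):
    config_name = "scheduler_config.json"
    ignore_for_config = ["tensor_format"]

    @register_to_config
    def __init__(self, num_train_timesteps=1000, tensor_format="np"):
        self.tensor_format = tensor_format  # still usable inside the class

s = ToyScheduler(tensor_format="pt")
assert s.config["num_train_timesteps"] == 1000
assert "tensor_format" not in s.config  # excluded via ignore_for_config
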
@@ -18,6 +18,7 @@ import inspect
import math
import tempfile
import unittest
from atexit import register

import numpy as np
import torch
@@ -38,7 +39,7 @@ from diffusers import (
    UNetUnconditionalModel,
    VQModel,
)
from diffusers.configuration_utils import ConfigMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.testing_utils import floats_tensor, slow, torch_device
from diffusers.training_utils import EMAModel
@@ -47,25 +48,63 @@ from diffusers.training_utils import EMAModel
torch.backends.cuda.matmul.allow_tf32 = False


class SampleObject(ConfigMixin):
    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        a=2,
        b=5,
        c=(2, 5),
        d="for diffusion",
        e=[1, 3],
    ):
        pass


class ConfigTester(unittest.TestCase):
    def test_load_not_from_mixin(self):
        with self.assertRaises(ValueError):
            ConfigMixin.from_config("dummy_path")

    def test_register_to_config(self):
        obj = SampleObject()
        config = obj.config
        assert config["a"] == 2
        assert config["b"] == 5
        assert config["c"] == (2, 5)
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

        # init ignore private arguments
        obj = SampleObject(_name_or_path="lalala")
        config = obj.config
        assert config["a"] == 2
        assert config["b"] == 5
        assert config["c"] == (2, 5)
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

        # can override default
        obj = SampleObject(c=6)
        config = obj.config
        assert config["a"] == 2
        assert config["b"] == 5
        assert config["c"] == 6
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

        # can use positional arguments.
        obj = SampleObject(1, c=6)
        config = obj.config
        assert config["a"] == 1
        assert config["b"] == 5
        assert config["c"] == 6
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

    def test_save_load(self):
        class SampleObject(ConfigMixin):
            config_name = "config.json"

            def __init__(
                self,
                a=2,
                b=5,
                c=(2, 5),
                d="for diffusion",
                e=[1, 3],
            ):
                self.register_to_config(a=a, b=b, c=c, d=d, e=e)

        obj = SampleObject()
        config = obj.config