1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

make Diffwave subclass of ModelMixin

This commit is contained in:
patil-suraj
2022-06-13 14:41:09 +02:00
parent 86da45bc66
commit bc72d297c6

View File

@@ -19,6 +19,8 @@ import torch.nn as nn
import torch.nn.functional as F
import tqdm
from ..modeling_utils import ModelMixin
from ..configuration_utils import ConfigMixin
from ..pipeline_utils import DiffusionPipeline
@@ -209,14 +211,35 @@ class ResidualGroup(nn.Module):
return skip * math.sqrt(1.0 / self.num_res_layers)
class DiffWave(nn.Module):
def __init__(self, in_channels, res_channels, skip_channels, out_channels,
num_res_layers, dilation_cycle,
diffusion_step_embed_dim_in,
diffusion_step_embed_dim_mid,
diffusion_step_embed_dim_out):
class DiffWave(ModelMixin, ConfigMixin):
def __init__(
self,
in_channels=1,
res_channels=128,
skip_channels=128,
out_channels=1,
num_res_layers=30,
dilation_cycle=10,
diffusion_step_embed_dim_in=128,
diffusion_step_embed_dim_mid=512,
diffusion_step_embed_dim_out=512,
):
super().__init__()
# register all init arguments with self.register
self.register(
in_channels=in_channels,
res_channels=res_channels,
skip_channels=skip_channels,
out_channels=out_channels,
num_res_layers=num_res_layers,
dilation_cycle=dilation_cycle,
diffusion_step_embed_dim_in=diffusion_step_embed_dim_in,
diffusion_step_embed_dim_mid=diffusion_step_embed_dim_mid,
diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
)
# Initial conv1x1 with relu
self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU(inplace=False))
# All residual layers