mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
make Diffwave subclass of ModelMixin
This commit is contained in:
@@ -19,6 +19,8 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import tqdm
|
||||
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
@@ -209,14 +211,35 @@ class ResidualGroup(nn.Module):
|
||||
return skip * math.sqrt(1.0 / self.num_res_layers)
|
||||
|
||||
|
||||
class DiffWave(nn.Module):
|
||||
def __init__(self, in_channels, res_channels, skip_channels, out_channels,
|
||||
num_res_layers, dilation_cycle,
|
||||
diffusion_step_embed_dim_in,
|
||||
diffusion_step_embed_dim_mid,
|
||||
diffusion_step_embed_dim_out):
|
||||
class DiffWave(ModelMixin, ConfigMixin):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=1,
|
||||
res_channels=128,
|
||||
skip_channels=128,
|
||||
out_channels=1,
|
||||
num_res_layers=30,
|
||||
dilation_cycle=10,
|
||||
diffusion_step_embed_dim_in=128,
|
||||
diffusion_step_embed_dim_mid=512,
|
||||
diffusion_step_embed_dim_out=512,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# register all init arguments with self.register
|
||||
self.register(
|
||||
in_channels=in_channels,
|
||||
res_channels=res_channels,
|
||||
skip_channels=skip_channels,
|
||||
out_channels=out_channels,
|
||||
num_res_layers=num_res_layers,
|
||||
dilation_cycle=dilation_cycle,
|
||||
diffusion_step_embed_dim_in=diffusion_step_embed_dim_in,
|
||||
diffusion_step_embed_dim_mid=diffusion_step_embed_dim_mid,
|
||||
diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
|
||||
)
|
||||
|
||||
|
||||
# Initial conv1x1 with relu
|
||||
self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU(inplace=False))
|
||||
# All residual layers
|
||||
|
||||
Reference in New Issue
Block a user