diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index dc69d8bf35..6339eb7c0a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -9,6 +9,7 @@ from .models.unet import UNetModel from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel from .models.unet_grad_tts import UNetGradTTSModel from .models.unet_ldm import UNetLDMModel +from .models.unet_rl import TemporalUNet from .pipeline_utils import DiffusionPipeline from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, GradTTS, LatentDiffusion from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 1a657f224e..5b1a46198e 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -20,3 +20,4 @@ from .unet import UNetModel from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel from .unet_grad_tts import UNetGradTTSModel from .unet_ldm import UNetLDMModel +from .unet_rl import TemporalUNet \ No newline at end of file diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index 2a0b441c8f..973828105d 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -6,6 +6,10 @@ import einops from einops.layers.torch import Rearrange import math +from ..configuration_utils import ConfigMixin +from ..modeling_utils import ModelMixin + + class SinusoidalPosEmb(nn.Module): def __init__(self, dim): super().__init__() @@ -85,7 +89,7 @@ class ResidualTemporalBlock(nn.Module): out = self.blocks[1](out) return out + self.residual_conv(x) -class TemporalUnet(nn.Module): +class TemporalUNet(ModelMixin, ConfigMixin): #(nn.Module): def __init__( self, @@ -99,7 +103,7 @@ class TemporalUnet(nn.Module): dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) - print(f'[ models/temporal ] Channel dimensions: {in_out}') + # print(f'[ models/temporal ] Channel dimensions: {in_out}') time_dim = dim self.time_mlp = nn.Sequential(