1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

add imports for RL UNet

This commit is contained in:
Nathan Lambert
2022-06-20 14:35:39 -04:00
parent 9c96682a51
commit 49718b4704
3 changed files with 8 additions and 2 deletions

View File

@@ -9,6 +9,7 @@ from .models.unet import UNetModel
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
from .models.unet_grad_tts import UNetGradTTSModel
from .models.unet_ldm import UNetLDMModel
from .models.unet_rl import TemporalUNet
from .pipeline_utils import DiffusionPipeline
from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, GradTTS, LatentDiffusion
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin

View File

@@ -20,3 +20,4 @@ from .unet import UNetModel
from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
from .unet_grad_tts import UNetGradTTSModel
from .unet_ldm import UNetLDMModel
from .unet_rl import TemporalUNet

View File

@@ -6,6 +6,10 @@ import einops
from einops.layers.torch import Rearrange
import math
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -85,7 +89,7 @@ class ResidualTemporalBlock(nn.Module):
out = self.blocks[1](out)
return out + self.residual_conv(x)
class TemporalUnet(nn.Module):
class TemporalUNet(ModelMixin, ConfigMixin): #(nn.Module):
def __init__(
self,
@@ -99,7 +103,7 @@ class TemporalUnet(nn.Module):
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
print(f'[ models/temporal ] Channel dimensions: {in_out}')
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
time_dim = dim
self.time_mlp = nn.Sequential(