mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add imports for RL UNet
This commit is contained in:
@@ -9,6 +9,7 @@ from .models.unet import UNetModel
|
||||
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
|
||||
from .models.unet_grad_tts import UNetGradTTSModel
|
||||
from .models.unet_ldm import UNetLDMModel
|
||||
from .models.unet_rl import TemporalUNet
|
||||
from .pipeline_utils import DiffusionPipeline
|
||||
from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, GradTTS, LatentDiffusion
|
||||
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
|
||||
|
||||
@@ -20,3 +20,4 @@ from .unet import UNetModel
|
||||
from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
|
||||
from .unet_grad_tts import UNetGradTTSModel
|
||||
from .unet_ldm import UNetLDMModel
|
||||
from .unet_rl import TemporalUNet
|
||||
@@ -6,6 +6,10 @@ import einops
|
||||
from einops.layers.torch import Rearrange
|
||||
import math
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
@@ -85,7 +89,7 @@ class ResidualTemporalBlock(nn.Module):
|
||||
out = self.blocks[1](out)
|
||||
return out + self.residual_conv(x)
|
||||
|
||||
class TemporalUnet(nn.Module):
|
||||
class TemporalUNet(ModelMixin, ConfigMixin): #(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -99,7 +103,7 @@ class TemporalUnet(nn.Module):
|
||||
|
||||
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
|
||||
in_out = list(zip(dims[:-1], dims[1:]))
|
||||
print(f'[ models/temporal ] Channel dimensions: {in_out}')
|
||||
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
|
||||
|
||||
time_dim = dim
|
||||
self.time_mlp = nn.Sequential(
|
||||
|
||||
Reference in New Issue
Block a user