mirror of https://github.com/huggingface/diffusers.git

commit ac796924df (parent bd9c9fbfbe)
Author: Patrick von Platen
Date: 2022-06-24 15:55:26 +00:00

    add score estimation model

4 changed files with 1070 additions and 19 deletions

src/diffusers/__init__.py

@@ -7,9 +7,7 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
 __version__ = "0.0.4"

 from .modeling_utils import ModelMixin
-from .models.unet import UNetModel
-from .models.unet_ldm import UNetLDMModel
-from .models.unet_rl import TemporalUNet
+from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel
 from .pipeline_utils import DiffusionPipeline
 from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline
 from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
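
With this change the three per-module imports collapse into a single re-export, so every model class is available from the package root. A minimal sketch of the new import surface (assumes diffusers 0.0.4 as of this commit is installed):

    # All four model classes now come from the top-level package.
    from diffusers import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel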

src/diffusers/models/__init__.py

@@ -21,3 +21,4 @@ from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, Glide
 from .unet_grad_tts import UNetGradTTSModel
 from .unet_ldm import UNetLDMModel
 from .unet_rl import TemporalUNet
+from .unet_sde_score_estimation import NCSNpp

src/diffusers/models/unet_rl.py

@@ -5,6 +5,7 @@ import math

 import torch
 import torch.nn as nn

+try:
 import einops
 from einops.layers.torch import Rearrange
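
The bare try: added here suggests the einops imports are being turned into an optional dependency; the matching except branch and the re-indentation of the import lines fall outside the visible part of the hunk. A sketch of the usual guard pattern (the except body is an assumption, not code from this commit):

    # Hypothetical optional-import guard; only the try: line is visible above.
    try:
        import einops
        from einops.layers.torch import Rearrange
    except ImportError:  # assumed fallback, not shown in the diff
        einops = None
        Rearrange = None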
@@ -104,14 +105,14 @@ class ResidualTemporalBlock(nn.Module):

 class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
     def __init__(
-        self,
-        training_horizon,
-        transition_dim,
-        cond_dim,
-        predict_epsilon=False,
-        clip_denoised=True,
-        dim=32,
-        dim_mults=(1, 2, 4, 8),
+        self,
+        training_horizon,
+        transition_dim,
+        cond_dim,
+        predict_epsilon=False,
+        clip_denoised=True,
+        dim=32,
+        dim_mults=(1, 2, 4, 8),
     ):
         super().__init__()
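
The -/+ blocks above differ only in leading whitespace, which the page extraction stripped, so this hunk reads as a pure re-indent: the constructor signature itself is unchanged. For orientation, a hypothetical instantiation of the signature shown (all argument values are illustrative placeholders, not taken from the diff):

    from diffusers import TemporalUNet

    model = TemporalUNet(
        training_horizon=128,    # planning horizon seen during training
        transition_dim=14,       # per-step observation + action dimensionality
        cond_dim=14,             # conditioning dimensionality
        predict_epsilon=False,   # predict the denoised sample rather than the noise
        clip_denoised=True,      # clamp denoised outputs
        dim=32,                  # base channel width
        dim_mults=(1, 2, 4, 8),  # channel multipliers per UNet level
    )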
@@ -211,14 +212,14 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):

 class TemporalValue(nn.Module):
     def __init__(
-        self,
-        horizon,
-        transition_dim,
-        cond_dim,
-        dim=32,
-        time_dim=None,
-        out_dim=1,
-        dim_mults=(1, 2, 4, 8),
+        self,
+        horizon,
+        transition_dim,
+        cond_dim,
+        dim=32,
+        time_dim=None,
+        out_dim=1,
+        dim_mults=(1, 2, 4, 8),
     ):
         super().__init__()

src/diffusers/models/unet_sde_score_estimation.py

File diff suppressed because it is too large
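
Although its diff is suppressed, this new file is the source of the NCSNpp class exported above: the NCSN++ score network from Song et al.'s score-based SDE work, which maps a noisy sample plus a noise level to an estimate of the score (the gradient of the log-density) with the same shape as the input. A rough usage sketch; the constructor arguments and call signature are assumptions, since the file's contents are not shown here:

    import torch

    from diffusers import NCSNpp

    # Hypothetical configuration; the real signature lives in the suppressed file.
    model = NCSNpp(image_size=32)

    x = torch.randn(4, 3, 32, 32)  # batch of noisy samples
    t = torch.rand(4)              # continuous noise levels in [0, 1]
    score = model(x, t)            # estimated score, same shape as x
    assert score.shape == x.shape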