mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add score estimation model
This commit is contained in:
@@ -7,9 +7,7 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
|
||||
__version__ = "0.0.4"
|
||||
|
||||
from .modeling_utils import ModelMixin
|
||||
from .models.unet import UNetModel
|
||||
from .models.unet_ldm import UNetLDMModel
|
||||
from .models.unet_rl import TemporalUNet
|
||||
from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel
|
||||
from .pipeline_utils import DiffusionPipeline
|
||||
from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline
|
||||
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
|
||||
|
||||
@@ -21,3 +21,4 @@ from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, Glide
|
||||
from .unet_grad_tts import UNetGradTTSModel
|
||||
from .unet_ldm import UNetLDMModel
|
||||
from .unet_rl import TemporalUNet
|
||||
from .unet_sde_score_estimation import NCSNpp
|
||||
|
||||
@@ -5,6 +5,7 @@ import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
try:
|
||||
import einops
|
||||
from einops.layers.torch import Rearrange
|
||||
@@ -104,14 +105,14 @@ class ResidualTemporalBlock(nn.Module):
|
||||
|
||||
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
training_horizon,
|
||||
transition_dim,
|
||||
cond_dim,
|
||||
predict_epsilon=False,
|
||||
clip_denoised=True,
|
||||
dim=32,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
self,
|
||||
training_horizon,
|
||||
transition_dim,
|
||||
cond_dim,
|
||||
predict_epsilon=False,
|
||||
clip_denoised=True,
|
||||
dim=32,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -211,14 +212,14 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
|
||||
|
||||
class TemporalValue(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
horizon,
|
||||
transition_dim,
|
||||
cond_dim,
|
||||
dim=32,
|
||||
time_dim=None,
|
||||
out_dim=1,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
self,
|
||||
horizon,
|
||||
transition_dim,
|
||||
cond_dim,
|
||||
dim=32,
|
||||
time_dim=None,
|
||||
out_dim=1,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
1051
src/diffusers/models/unet_sde_score_estimation.py
Normal file
1051
src/diffusers/models/unet_sde_score_estimation.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user