mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
@@ -7,5 +7,6 @@ __version__ = "0.0.1"
|
||||
from .modeling_utils import ModelMixin
|
||||
from .models.unet import UNetModel
|
||||
from .models.unet_glide import UNetGLIDEModel
|
||||
from .models.unet_ldm import UNetLDMModel
|
||||
from .pipeline_utils import DiffusionPipeline
|
||||
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
|
||||
|
||||
@@ -18,3 +18,4 @@
|
||||
|
||||
from .unet import UNetModel
|
||||
from .unet_glide import UNetGLIDEModel
|
||||
from .unet_ldm import UNetLDMModel
|
||||
|
||||
@@ -830,7 +830,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
||||
self.conv_resample = conv_resample
|
||||
self.num_classes = num_classes
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.dtype = torch.float16 if use_fp16 else torch.float32
|
||||
self.dtype_ = torch.float16 if use_fp16 else torch.float32
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
@@ -1060,7 +1060,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
||||
assert y.shape == (x.shape[0],)
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x.type(self.dtype)
|
||||
h = x.type(self.dtype_)
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb, context)
|
||||
hs.append(h)
|
||||
|
||||
Reference in New Issue
Block a user