1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
sayakpaul
2025-10-20 20:19:07 -10:00
parent 5d30c5bd00
commit 231b316fc0
3 changed files with 6 additions and 5 deletions

View File

@@ -20,10 +20,10 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
-from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
-class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
+class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
Designing a Better Asymmetric VQGAN for StableDiffusion https://huggingface.co/papers/2306.04632 . A VAE model with
KL loss for encoding images into latents and decoding latent representations into images.

View File

@@ -23,7 +23,7 @@ from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
-from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder
class TemporalDecoder(nn.Module):
@@ -135,7 +135,7 @@ class TemporalDecoder(nn.Module):
return sample
-class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
+class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.

View File

@@ -22,6 +22,7 @@ from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ..autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
from ..modeling_utils import ModelMixin
+from .vae import AutoencoderMixin
@dataclass
@@ -37,7 +38,7 @@ class VQEncoderOutput(BaseOutput):
latents: torch.Tensor
-class VQModel(ModelMixin, ConfigMixin):
+class VQModel(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VQ-VAE model for decoding latent representations.