diff --git a/src/diffusers/models/autoencoders/vq_model.py b/src/diffusers/models/autoencoders/vq_model.py index 2f9e75623e..ae8a118d71 100644 --- a/src/diffusers/models/autoencoders/vq_model.py +++ b/src/diffusers/models/autoencoders/vq_model.py @@ -166,12 +166,12 @@ class VQModel(ModelMixin, ConfigMixin): Args: sample (`torch.Tensor`): Input sample. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple. + Whether or not to return a [`models.autoencoders.vq_model.VQEncoderOutput`] instead of a plain tuple. Returns: - [`~models.vq_model.VQEncoderOutput`] or `tuple`: - If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple` - is returned. + [`~models.autoencoders.vq_model.VQEncoderOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoders.vq_model.VQEncoderOutput`] is returned, otherwise a + plain `tuple` is returned. """ h = self.encode(sample).latents diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index 71aeb09049..f946e46344 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -16,10 +16,14 @@ from .autoencoders.vq_model import VQEncoderOutput, VQModel class VQEncoderOutput(VQEncoderOutput): - deprecation_message = "Importing `VQEncoderOutput` from `diffusers.models.vq_model` is deprecated and this will be removed in a future version. Please use `from diffusers.models.autoencoders.vq_model import VQEncoderOutput`, instead." - deprecate("VQEncoderOutput", "0.31", deprecation_message) + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `VQEncoderOutput` from `diffusers.models.vq_model` is deprecated and this will be removed in a future version. Please use `from diffusers.models.autoencoders.vq_model import VQEncoderOutput`, instead." + deprecate("VQEncoderOutput", "0.31", deprecation_message) + super().__init__(*args, **kwargs) class VQModel(VQModel): - deprecation_message = "Importing `VQModel` from `diffusers.models.vq_model` is deprecated and this will be removed in a future version. Please use `from diffusers.models.autoencoders.vq_model import VQModel`, instead." - deprecate("VQModel", "0.31", deprecation_message) + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `VQModel` from `diffusers.models.vq_model` is deprecated and this will be removed in a future version. Please use `from diffusers.models.autoencoders.vq_model import VQModel`, instead." + deprecate("VQModel", "0.31", deprecation_message) + super().__init__(*args, **kwargs)