From 84a89cc90c8e6c633d9673d2c99337a7976f1d61 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 17 Apr 2023 18:16:28 +0200
Subject: [PATCH] Fix config deprecation (#3129)

* Better deprecation message

* Better deprecation message

* Better doc string

* Fixes

* fix more

* fix more

* Improve __getattr__

* correct more

* fix more

* fix

* Improve more

* more improvements

* fix more

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca

* make style

* Fix all rest & add tests & remove old deprecation fns

---------

Co-authored-by: Pedro Cuenca
---
 .../community/unclip_image_interpolation.py   | 12 ++--
 .../community/unclip_text_interpolation.py    | 12 ++--
 src/diffusers/configuration_utils.py          | 18 +++++
 src/diffusers/models/autoencoder_kl.py        | 12 +---
 src/diffusers/models/modeling_utils.py        | 21 +++++-
 src/diffusers/models/unet_1d.py               | 12 +---
 src/diffusers/models/unet_2d.py               | 12 +---
 src/diffusers/models/unet_2d_condition.py     | 12 +---
 src/diffusers/pipelines/pipeline_utils.py     | 69 +++++++++----------
 .../pipeline_text_to_video_zero.py            |  2 +-
 .../pipelines/unclip/pipeline_unclip.py       | 12 ++--
 .../unclip/pipeline_unclip_image_variation.py | 12 ++--
 .../versatile_diffusion/modeling_text_unet.py | 15 +---
 ...ipeline_versatile_diffusion_dual_guided.py |  2 +-
 ...ine_versatile_diffusion_image_variation.py |  2 +-
 ...eline_versatile_diffusion_text_to_image.py |  2 +-
 src/diffusers/schedulers/scheduling_ddpm.py   | 12 +---
 src/diffusers/utils/deprecation_utils.py      |  4 +-
 tests/pipelines/unclip/test_unclip.py         |  8 +--
 .../unclip/test_unclip_image_variation.py     | 13 ++--
 tests/schedulers/test_schedulers.py           | 44 ++++++++++++
 tests/test_modeling_common.py                 | 47 ++++++++++++-
 22 files changed, 209 insertions(+), 146 deletions(-)

diff --git a/examples/community/unclip_image_interpolation.py b/examples/community/unclip_image_interpolation.py
index d0b54125b6..453ac07af7 100644
--- a/examples/community/unclip_image_interpolation.py
+++ b/examples/community/unclip_image_interpolation.py
@@ -372,9 +372,9 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
         self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
         decoder_timesteps_tensor = self.decoder_scheduler.timesteps

-        num_channels_latents = self.decoder.in_channels
-        height = self.decoder.sample_size
-        width = self.decoder.sample_size
+        num_channels_latents = self.decoder.config.in_channels
+        height = self.decoder.config.sample_size
+        width = self.decoder.config.sample_size

         decoder_latents = self.prepare_latents(
             (batch_size, num_channels_latents, height, width),
@@ -425,9 +425,9 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
         self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
         super_res_timesteps_tensor = self.super_res_scheduler.timesteps

-        channels = self.super_res_first.in_channels // 2
-        height = self.super_res_first.sample_size
-        width = self.super_res_first.sample_size
+        channels = self.super_res_first.config.in_channels // 2
+        height = self.super_res_first.config.sample_size
+        width = self.super_res_first.config.sample_size

         super_res_latents = self.prepare_latents(
             (batch_size, channels, height, width),
diff --git a/examples/community/unclip_text_interpolation.py b/examples/community/unclip_text_interpolation.py
index ac6b73d974..290f453170 100644
--- a/examples/community/unclip_text_interpolation.py
+++ b/examples/community/unclip_text_interpolation.py
@@ -452,9 +452,9 @@ class UnCLIPTextInterpolationPipeline(DiffusionPipeline):
         self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
         decoder_timesteps_tensor = self.decoder_scheduler.timesteps

-        num_channels_latents = self.decoder.in_channels
-        height = self.decoder.sample_size
-        width = self.decoder.sample_size
+        num_channels_latents = self.decoder.config.in_channels
+        height = self.decoder.config.sample_size
+        width = self.decoder.config.sample_size

         decoder_latents = self.prepare_latents(
             (batch_size, num_channels_latents, height, width),
@@ -505,9 +505,9 @@ class UnCLIPTextInterpolationPipeline(DiffusionPipeline):
         self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
         super_res_timesteps_tensor = self.super_res_scheduler.timesteps

-        channels = self.super_res_first.in_channels // 2
-        height = self.super_res_first.sample_size
-        width = self.super_res_first.sample_size
+        channels = self.super_res_first.config.in_channels // 2
+        height = self.super_res_first.config.sample_size
+        width = self.super_res_first.config.sample_size

         super_res_latents = self.prepare_latents(
             (batch_size, channels, height, width),
diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index 4593043135..772e119fbe 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -118,6 +118,24 @@ class ConfigMixin:

         self._internal_dict = FrozenDict(internal_dict)

+    def __getattr__(self, name: str) -> Any:
+        """The only reason we override `__getattr__` here is to gracefully deprecate accessing
+        config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
+
+        This function is mostly copied from PyTorch's `__getattr__` override:
+        https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
+        """
+
+        is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
+        is_attribute = name in self.__dict__
+
+        if is_in_config and not is_attribute:
+            deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' via '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
+            deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
+            return self._internal_dict[name]
+
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
     def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
         """
         Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
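The shim above keeps old call sites working: an attribute lookup that fails normally but matches a config key returns the config value and emits a `FutureWarning`. A minimal sketch of the resulting behavior (not part of the patch; assumes a default `DDPMScheduler`, whose config sets `num_train_timesteps=1000`):

    import warnings

    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler()

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Deprecated direct access: resolved by `ConfigMixin.__getattr__`,
        # which warns and reads the value from the internal config dict.
        num_steps = scheduler.num_train_timesteps

    assert num_steps == scheduler.config.num_train_timesteps == 1000
    assert any(issubclass(w.category, FutureWarning) for w in caught)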
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False) + return self._internal_dict[name] + + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): """ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 5d1c54a9af..1a8a204d80 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, apply_forward_hook, deprecate +from ..utils import BaseOutput, apply_forward_hook from .modeling_utils import ModelMixin from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder @@ -123,16 +123,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin): self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) self.tile_overlap_factor = 0.25 - @property - def block_out_channels(self): - deprecate( - "block_out_channels", - "1.0.0", - "Accessing `block_out_channels` directly via vae.block_out_channels is deprecated. Please use `vae.config.block_out_channels instead`", - standard_warn=False, - ) - return self.config.block_out_channels - def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (Encoder, Decoder)): module.gradient_checkpointing = value diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 6a849f6f0e..5363e63306 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -17,7 +17,7 @@ import inspect import os from functools import partial -from typing import Callable, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union import torch from torch import Tensor, device @@ -32,6 +32,7 @@ from ..utils import ( WEIGHTS_NAME, _add_variant, _get_model_file, + deprecate, is_accelerate_available, is_safetensors_available, is_torch_version, @@ -156,6 +157,24 @@ class ModelMixin(torch.nn.Module): def __init__(self): super().__init__() + def __getattr__(self, name: str) -> Any: + """The only reason we overwrite `getattr` here is to gracefully deprecate accessing + config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite + __getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__': + https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module + """ + + is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name) + is_attribute = name in self.__dict__ + + if is_in_config and not is_attribute: + deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'." 
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3) + return self._internal_dict[name] + + # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module + return super().__getattr__(name) + @property def is_gradient_checkpointing(self) -> bool: """ diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index c7755bb3ed..34a1d2b516 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -19,7 +19,7 @@ import torch import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, deprecate +from ..utils import BaseOutput from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block @@ -190,16 +190,6 @@ class UNet1DModel(ModelMixin, ConfigMixin): fc_dim=block_out_channels[-1] // 4, ) - @property - def in_channels(self): - deprecate( - "in_channels", - "1.0.0", - "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead", - standard_warn=False, - ) - return self.config.in_channels - def forward( self, sample: torch.FloatTensor, diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index a83e4917c1..2a6a1b9de5 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, deprecate +from ..utils import BaseOutput from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block @@ -216,16 +216,6 @@ class UNet2DModel(ModelMixin, ConfigMixin): self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) - @property - def in_channels(self): - deprecate( - "in_channels", - "1.0.0", - "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead", - standard_warn=False, - ) - return self.config.in_channels - def forward( self, sample: torch.FloatTensor, diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 1b982aedc5..b281435693 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -21,7 +21,7 @@ import torch.utils.checkpoint from ..configuration_utils import ConfigMixin, register_to_config from ..loaders import UNet2DConditionLoadersMixin -from ..utils import BaseOutput, deprecate, logging +from ..utils import BaseOutput, logging from .attention_processor import AttentionProcessor, AttnProcessor from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin @@ -447,16 +447,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding ) - @property - def in_channels(self): - deprecate( - "in_channels", - "1.0.0", - "Accessing `in_channels` directly via unet.in_channels is deprecated. 
-            standard_warn=False,
-        )
-        return self.config.in_channels
-
     @property
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
         r"""
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 2e20c21aaf..9daab21de7 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -507,7 +507,7 @@ class DiffusionPipeline(ConfigMixin):
             setattr(self, name, module)

     def __setattr__(self, name: str, value: Any):
-        if hasattr(self, name) and hasattr(self.config, name):
+        if name in self.__dict__ and hasattr(self.config, name):
             # We need to overwrite the config if name exists in config
             if isinstance(getattr(self.config, name), (tuple, list)):
                 if value is not None and self.config[name][0] is not None:
@@ -635,26 +635,25 @@ class DiffusionPipeline(ConfigMixin):
             )

         module_names, _ = self._get_signature_keys(self)
-        module_names = [m for m in module_names if hasattr(self, m)]
+        modules = [getattr(self, n, None) for n in module_names]
+        modules = [m for m in modules if isinstance(m, torch.nn.Module)]

         is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
-        for name in module_names:
-            module = getattr(self, name)
-            if isinstance(module, torch.nn.Module):
-                module.to(torch_device, torch_dtype)
-                if (
-                    module.dtype == torch.float16
-                    and str(torch_device) in ["cpu"]
-                    and not silence_dtype_warnings
-                    and not is_offloaded
-                ):
-                    logger.warning(
-                        "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
-                        " is not recommended to move them to `cpu` as running them will fail. Please make"
-                        " sure to use an accelerator to run the pipeline in inference, due to the lack of"
-                        " support for`float16` operations on this device in PyTorch. Please, remove the"
-                        " `torch_dtype=torch.float16` argument, or use another device for inference."
-                    )
+        for module in modules:
+            module.to(torch_device, torch_dtype)
+            if (
+                module.dtype == torch.float16
+                and str(torch_device) in ["cpu"]
+                and not silence_dtype_warnings
+                and not is_offloaded
+            ):
+                logger.warning(
+                    "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
+                    " is not recommended to move them to `cpu` as running them will fail. Please make"
+                    " sure to use an accelerator to run the pipeline in inference, due to the lack of"
+                    " support for `float16` operations on this device in PyTorch. Please remove the"
+                    " `torch_dtype=torch.float16` argument, or use another device for inference."
+                )
         return self

     @property
     def device(self) -> torch.device:
         r"""
         Returns:
             `torch.device`: The torch device on which the pipeline is located.
         """
         module_names, _ = self._get_signature_keys(self)
-        module_names = [m for m in module_names if hasattr(self, m)]
+        modules = [getattr(self, n, None) for n in module_names]
+        modules = [m for m in modules if isinstance(m, torch.nn.Module)]
+
+        for module in modules:
+            return module.device

-        for name in module_names:
-            module = getattr(self, name)
-            if isinstance(module, torch.nn.Module):
-                return module.device
         return torch.device("cpu")
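A subtlety behind the `__setattr__` change above: once `ConfigMixin` gains a `__getattr__`, `hasattr(self, name)` is also true for names that exist only in the config, so the guard has to check `self.__dict__` to detect genuine instance attributes. A toy illustration of the difference (hypothetical class, not diffusers code):

    class Configured:
        def __init__(self):
            self._internal_dict = {"sample_size": 64}

        def __getattr__(self, name):
            # Only invoked when normal attribute lookup fails.
            cfg = self.__dict__.get("_internal_dict", {})
            if name in cfg:
                return cfg[name]
            raise AttributeError(name)

    obj = Configured()
    assert hasattr(obj, "sample_size")        # True: satisfied via __getattr__
    assert "sample_size" not in obj.__dict__  # but not a real instance attribute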
""" module_names, _ = self._get_signature_keys(self) - module_names = [m for m in module_names if hasattr(self, m)] + modules = [getattr(self, n, None) for n in module_names] + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.device - for name in module_names: - module = getattr(self, name) - if isinstance(module, torch.nn.Module): - return module.device return torch.device("cpu") @classmethod @@ -1438,13 +1437,12 @@ class DiffusionPipeline(ConfigMixin): for child in module.children(): fn_recursive_set_mem_eff(child) - module_names, _, _ = self.extract_init_dict(dict(self.config)) - module_names = [m for m in module_names if hasattr(self, m)] + module_names, _ = self._get_signature_keys(self) + modules = [getattr(self, n, None) for n in module_names] + modules = [m for m in modules if isinstance(m, torch.nn.Module)] - for module_name in module_names: - module = getattr(self, module_name) - if isinstance(module, torch.nn.Module): - fn_recursive_set_mem_eff(module) + for module in modules: + fn_recursive_set_mem_eff(module) def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): r""" @@ -1471,10 +1469,9 @@ class DiffusionPipeline(ConfigMixin): self.enable_attention_slicing(None) def set_attention_slice(self, slice_size: Optional[int]): - module_names, _, _ = self.extract_init_dict(dict(self.config)) - module_names = [m for m in module_names if hasattr(self, m)] + module_names, _ = self._get_signature_keys(self) + modules = [getattr(self, n, None) for n in module_names] + modules = [m for m in modules if isinstance(m, torch.nn.Module) and hasattr(m, "set_attention_slice")] - for module_name in module_names: - module = getattr(self, module_name) - if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size) + for module in modules: + module.set_attention_slice(slice_size) diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py index cf5e6e399a..5b163bbbc8 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py @@ -441,7 +441,7 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline): timesteps = self.scheduler.timesteps # Prepare latent variables - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py index 3aac39b3a3..abbb48ce8f 100644 --- a/src/diffusers/pipelines/unclip/pipeline_unclip.py +++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py @@ -413,9 +413,9 @@ class UnCLIPPipeline(DiffusionPipeline): self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) decoder_timesteps_tensor = self.decoder_scheduler.timesteps - num_channels_latents = self.decoder.in_channels - height = self.decoder.sample_size - width = self.decoder.sample_size + num_channels_latents = self.decoder.config.in_channels + height = self.decoder.config.sample_size + width = self.decoder.config.sample_size decoder_latents = self.prepare_latents( (batch_size, num_channels_latents, height, width), @@ -466,9 +466,9 @@ class UnCLIPPipeline(DiffusionPipeline): 
         self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
         super_res_timesteps_tensor = self.super_res_scheduler.timesteps

-        channels = self.super_res_first.in_channels // 2
-        height = self.super_res_first.sample_size
-        width = self.super_res_first.sample_size
+        channels = self.super_res_first.config.in_channels // 2
+        height = self.super_res_first.config.sample_size
+        width = self.super_res_first.config.sample_size

         super_res_latents = self.prepare_latents(
             (batch_size, channels, height, width),
diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
index 56d522354d..30d74cd36b 100644
--- a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
+++ b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
@@ -339,9 +339,9 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
         self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
         decoder_timesteps_tensor = self.decoder_scheduler.timesteps

-        num_channels_latents = self.decoder.in_channels
-        height = self.decoder.sample_size
-        width = self.decoder.sample_size
+        num_channels_latents = self.decoder.config.in_channels
+        height = self.decoder.config.sample_size
+        width = self.decoder.config.sample_size

         if decoder_latents is None:
             decoder_latents = self.prepare_latents(
@@ -393,9 +393,9 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
         self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
         super_res_timesteps_tensor = self.super_res_scheduler.timesteps

-        channels = self.super_res_first.in_channels // 2
-        height = self.super_res_first.sample_size
-        width = self.super_res_first.sample_size
+        channels = self.super_res_first.config.in_channels // 2
+        height = self.super_res_first.config.sample_size
+        width = self.super_res_first.config.sample_size

         if super_res_latents is None:
             super_res_latents = self.prepare_latents(
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index 35ddfcadc3..4377be1181 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -18,7 +18,7 @@ from ...models.dual_transformer_2d import DualTransformer2DModel
 from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from ...models.transformer_2d import Transformer2DModel
 from ...models.unet_2d_condition import UNet2DConditionOutput
-from ...utils import deprecate, logging
+from ...utils import logging


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -544,19 +544,6 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         )

-    @property
-    def in_channels(self):
-        deprecate(
-            "in_channels",
-            "1.0.0",
-            (
-                "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use"
Please use" - " `unet.config.in_channels` instead" - ), - standard_warn=False, - ) - return self.config.in_channels - @property def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py index 0f385ed661..661a1bd3cf 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -533,7 +533,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.image_unet.in_channels + num_channels_latents = self.image_unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py index 2b47184d77..e3a2ee3703 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -378,7 +378,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.image_unet.in_channels + num_channels_latents = self.image_unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py index fdca625fd9..26b9be2bfa 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -452,7 +452,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.image_unet.in_channels + num_channels_latents = self.image_unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index eaaf497f9c..9fb36db52d 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -22,7 +22,7 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, deprecate, randn_tensor +from ..utils import BaseOutput, randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin @@ -167,16 +167,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): self.variance_type = variance_type - @property - def num_train_timesteps(self): - deprecate( - "num_train_timesteps", - "1.0.0", - "Accessing `num_train_timesteps` directly via scheduler.num_train_timesteps is deprecated. 
-            standard_warn=False,
-        )
-        return self.config.num_train_timesteps
-
     def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
diff --git a/src/diffusers/utils/deprecation_utils.py b/src/diffusers/utils/deprecation_utils.py
index 6bdda664e1..f482deddd2 100644
--- a/src/diffusers/utils/deprecation_utils.py
+++ b/src/diffusers/utils/deprecation_utils.py
@@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Union
 from packaging import version


-def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True):
+def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
     from .. import __version__

     deprecated_kwargs = take_from
@@ -32,7 +32,7 @@ def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn
     if warning is not None:
         warning = warning + " " if standard_warn else ""
-        warnings.warn(warning + message, FutureWarning, stacklevel=2)
+        warnings.warn(warning + message, FutureWarning, stacklevel=stacklevel)

     if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0:
         call_frame = inspect.getouterframes(inspect.currentframe())[1]
diff --git a/tests/pipelines/unclip/test_unclip.py b/tests/pipelines/unclip/test_unclip.py
index c36fb02b19..29eab0cf80 100644
--- a/tests/pipelines/unclip/test_unclip.py
+++ b/tests/pipelines/unclip/test_unclip.py
@@ -293,16 +293,16 @@ class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         prior_latents = pipe.prepare_latents(
             shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
         )
-        shape = (batch_size, decoder.in_channels, decoder.sample_size, decoder.sample_size)
+        shape = (batch_size, decoder.config.in_channels, decoder.config.sample_size, decoder.config.sample_size)
         decoder_latents = pipe.prepare_latents(
             shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
         )

         shape = (
             batch_size,
-            super_res_first.in_channels // 2,
-            super_res_first.sample_size,
-            super_res_first.sample_size,
+            super_res_first.config.in_channels // 2,
+            super_res_first.config.sample_size,
+            super_res_first.config.sample_size,
         )
         super_res_latents = pipe.prepare_latents(
             shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
diff --git a/tests/pipelines/unclip/test_unclip_image_variation.py b/tests/pipelines/unclip/test_unclip_image_variation.py
index 3cacb0bcad..db88fbd3e2 100644
--- a/tests/pipelines/unclip/test_unclip_image_variation.py
+++ b/tests/pipelines/unclip/test_unclip_image_variation.py
@@ -379,16 +379,21 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
         dtype = pipe.decoder.dtype
         batch_size = 1

-        shape = (batch_size, pipe.decoder.in_channels, pipe.decoder.sample_size, pipe.decoder.sample_size)
+        shape = (
+            batch_size,
+            pipe.decoder.config.in_channels,
+            pipe.decoder.config.sample_size,
+            pipe.decoder.config.sample_size,
+        )
         decoder_latents = pipe.prepare_latents(
             shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
         )

         shape = (
             batch_size,
-            pipe.super_res_first.in_channels // 2,
-            pipe.super_res_first.sample_size,
-            pipe.super_res_first.sample_size,
+            pipe.super_res_first.config.in_channels // 2,
+            pipe.super_res_first.config.sample_size,
+            pipe.super_res_first.config.sample_size,
         )
         super_res_latents = pipe.prepare_latents(
             shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
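For reference, the `stacklevel` argument added to `deprecate` above is forwarded straight to `warnings.warn`: the default of 2 attributes the warning to the line that calls `deprecate`, while `ModelMixin.__getattr__` passes 3 to skip its own frame as well. A small sketch with a hypothetical deprecated helper (the helper name is made up for illustration):

    import warnings

    from diffusers.utils import deprecate

    def legacy_sample_size():
        # Default stacklevel=2: the warning is attributed to this call site.
        deprecate("legacy_sample_size", "1.0.0", "Use `config.sample_size` instead.", standard_warn=False)
        return 64

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        assert legacy_sample_size() == 64

    assert issubclass(caught[0].category, FutureWarning)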
diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py
index bfbf5cbc79..69cddb36dd 100755
--- a/tests/schedulers/test_schedulers.py
+++ b/tests/schedulers/test_schedulers.py
@@ -596,3 +596,47 @@ class SchedulerCommonTest(unittest.TestCase):
         new_scheduler = scheduler_class.from_pretrained(tmpdirname)

         assert scheduler.betas.tolist() == new_scheduler.betas.tolist()
+
+    def test_getattr_is_correct(self):
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            # save some things to test
+            scheduler.dummy_attribute = 5
+            scheduler.register_to_config(test_attribute=5)
+
+            logger = logging.get_logger("diffusers.configuration_utils")
+            # 30 for warning
+            logger.setLevel(30)
+            with CaptureLogger(logger) as cap_logger:
+                assert hasattr(scheduler, "dummy_attribute")
+                assert getattr(scheduler, "dummy_attribute") == 5
+                assert scheduler.dummy_attribute == 5
+
+            # no warning should be thrown
+            assert cap_logger.out == ""
+
+            logger = logging.get_logger("diffusers.schedulers.scheduling_utils")
+            # 30 for warning
+            logger.setLevel(30)
+            with CaptureLogger(logger) as cap_logger:
+                assert hasattr(scheduler, "save_pretrained")
+                fn = scheduler.save_pretrained
+                fn_1 = getattr(scheduler, "save_pretrained")
+
+                assert fn == fn_1
+            # no warning should be thrown
+            assert cap_logger.out == ""
+
+            # warning should be thrown
+            with self.assertWarns(FutureWarning):
+                assert scheduler.test_attribute == 5
+
+            with self.assertWarns(FutureWarning):
+                assert getattr(scheduler, "test_attribute") == 5
+
+            with self.assertRaises(AttributeError) as error:
+                scheduler.does_not_exist
+
+            assert str(error.exception) == f"'{type(scheduler).__name__}' object has no attribute 'does_not_exist'"
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 40aba3b249..4a94a77fca 100644
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -26,8 +26,8 @@ from requests.exceptions import HTTPError
 from diffusers.models import UNet2DConditionModel
 from diffusers.training_utils import EMAModel
-from diffusers.utils import torch_device
-from diffusers.utils.testing_utils import require_torch_gpu
+from diffusers.utils import logging, torch_device
+from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu


 class ModelUtilsTest(unittest.TestCase):
@@ -155,6 +155,49 @@ class ModelTesterMixin:
         max_diff = (image - new_image).abs().sum().item()
         self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")

+    def test_getattr_is_correct(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        # save some things to test
+        model.dummy_attribute = 5
+        model.register_to_config(test_attribute=5)
+
+        logger = logging.get_logger("diffusers.models.modeling_utils")
+        # 30 for warning
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            assert hasattr(model, "dummy_attribute")
+            assert getattr(model, "dummy_attribute") == 5
+            assert model.dummy_attribute == 5
+
+        # no warning should be thrown
+        assert cap_logger.out == ""
+
+        logger = logging.get_logger("diffusers.models.modeling_utils")
+        # 30 for warning
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            assert hasattr(model, "save_pretrained")
"save_pretrained") + fn = model.save_pretrained + fn_1 = getattr(model, "save_pretrained") + + assert fn == fn_1 + # no warning should be thrown + assert cap_logger.out == "" + + # warning should be thrown + with self.assertWarns(FutureWarning): + assert model.test_attribute == 5 + + with self.assertWarns(FutureWarning): + assert getattr(model, "test_attribute") == 5 + + with self.assertRaises(AttributeError) as error: + model.does_not_exist + + assert str(error.exception) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'" + def test_from_save_pretrained_variant(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()