Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
Fix config deprecation (#3129)
* Better deprecation message
* Better deprecation message
* Better doc string
* Fixes
* fix more
* fix more
* Improve __getattr__
* correct more
* fix more
* fix
* Improve more
* more improvements
* fix more
* Apply suggestions from code review

  Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* make style
* Fix all rest & add tests & remove old deprecation fns

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
@@ -372,9 +372,9 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
         self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
         decoder_timesteps_tensor = self.decoder_scheduler.timesteps

-        num_channels_latents = self.decoder.in_channels
-        height = self.decoder.sample_size
-        width = self.decoder.sample_size
+        num_channels_latents = self.decoder.config.in_channels
+        height = self.decoder.config.sample_size
+        width = self.decoder.config.sample_size

         decoder_latents = self.prepare_latents(
             (batch_size, num_channels_latents, height, width),
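The hunks in this commit repeat one migration: pipeline code stops reading model hyperparameters as direct attributes and reads them from the model's config object instead. A minimal sketch of the before/after pattern (`pipe` is a hypothetical UnCLIP-style pipeline; only the attribute names are taken from the diff):

```python
# Sketch of the access-pattern change, assuming `pipe` is an UnCLIP-style
# pipeline whose `decoder` is a diffusers model carrying a frozen config.

# Old style: still works after this commit, but now routes through the new
# `__getattr__` fallback and emits a FutureWarning (removal planned for 1.0.0).
num_channels_latents = pipe.decoder.in_channels

# New style: read hyperparameters from the config object.
num_channels_latents = pipe.decoder.config.in_channels
height = pipe.decoder.config.sample_size
width = pipe.decoder.config.sample_size
```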
@@ -425,9 +425,9 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
         self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
         super_res_timesteps_tensor = self.super_res_scheduler.timesteps

-        channels = self.super_res_first.in_channels // 2
-        height = self.super_res_first.sample_size
-        width = self.super_res_first.sample_size
+        channels = self.super_res_first.config.in_channels // 2
+        height = self.super_res_first.config.sample_size
+        width = self.super_res_first.config.sample_size

         super_res_latents = self.prepare_latents(
             (batch_size, channels, height, width),
@@ -452,9 +452,9 @@ class UnCLIPTextInterpolationPipeline(DiffusionPipeline):
         self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
         decoder_timesteps_tensor = self.decoder_scheduler.timesteps

-        num_channels_latents = self.decoder.in_channels
-        height = self.decoder.sample_size
-        width = self.decoder.sample_size
+        num_channels_latents = self.decoder.config.in_channels
+        height = self.decoder.config.sample_size
+        width = self.decoder.config.sample_size

         decoder_latents = self.prepare_latents(
             (batch_size, num_channels_latents, height, width),
@@ -505,9 +505,9 @@ class UnCLIPTextInterpolationPipeline(DiffusionPipeline):
         self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
         super_res_timesteps_tensor = self.super_res_scheduler.timesteps

-        channels = self.super_res_first.in_channels // 2
-        height = self.super_res_first.sample_size
-        width = self.super_res_first.sample_size
+        channels = self.super_res_first.config.in_channels // 2
+        height = self.super_res_first.config.sample_size
+        width = self.super_res_first.config.sample_size

         super_res_latents = self.prepare_latents(
             (batch_size, channels, height, width),
@@ -118,6 +118,24 @@ class ConfigMixin:

         self._internal_dict = FrozenDict(internal_dict)

+    def __getattr__(self, name: str) -> Any:
+        """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
+        config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
+
+        This function is mostly copied from PyTorch's __getattr__ overwrite:
+        https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
+        """
+
+        is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
+        is_attribute = name in self.__dict__
+
+        if is_in_config and not is_attribute:
+            deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
+            deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
+            return self._internal_dict[name]
+
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
     def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
         """
         Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
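Python only calls `__getattr__` after normal attribute lookup fails, so real attributes and methods never hit this fallback; only names that exist solely in the frozen config dict take the deprecation path. That is why the tests added below can assert that `save_pretrained` and plain instance attributes produce no warning. A simplified, self-contained sketch of the mechanism (the class is hypothetical; the real implementation uses `FrozenDict` and `deprecate`):

```python
# Simplified sketch of the fallback mechanism the new ConfigMixin.__getattr__
# relies on. `DemoConfigMixin` is hypothetical; only the control flow mirrors
# the diff above.
import warnings


class DemoConfigMixin:
    def __init__(self, **config):
        # Plain dict standing in for diffusers' FrozenDict.
        self._internal_dict = config

    def __getattr__(self, name):
        # Only reached when normal attribute lookup has already failed,
        # so real attributes and methods never pay this cost.
        if "_internal_dict" in self.__dict__ and name in self.__dict__["_internal_dict"]:
            warnings.warn(
                f"Accessing config attribute `{name}` directly is deprecated; "
                f"use `obj.config.{name}` instead.",
                FutureWarning,
            )
            return self._internal_dict[name]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")


obj = DemoConfigMixin(in_channels=4)
print(obj.in_channels)  # 4, accompanied by a FutureWarning
```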
@@ -18,7 +18,7 @@ import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, apply_forward_hook, deprecate
+from ..utils import BaseOutput, apply_forward_hook
 from .modeling_utils import ModelMixin
 from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
@@ -123,16 +123,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
         self.tile_overlap_factor = 0.25

-    @property
-    def block_out_channels(self):
-        deprecate(
-            "block_out_channels",
-            "1.0.0",
-            "Accessing `block_out_channels` directly via vae.block_out_channels is deprecated. Please use `vae.config.block_out_channels instead`",
-            standard_warn=False,
-        )
-        return self.config.block_out_channels
-
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (Encoder, Decoder)):
             module.gradient_checkpointing = value
@@ -17,7 +17,7 @@
 import inspect
 import os
 from functools import partial
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union

 import torch
 from torch import Tensor, device
@@ -32,6 +32,7 @@ from ..utils import (
     WEIGHTS_NAME,
     _add_variant,
     _get_model_file,
+    deprecate,
     is_accelerate_available,
     is_safetensors_available,
     is_torch_version,
@@ -156,6 +157,24 @@ class ModelMixin(torch.nn.Module):
     def __init__(self):
         super().__init__()

+    def __getattr__(self, name: str) -> Any:
+        """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
+        config attributes directly. See https://github.com/huggingface/diffusers/pull/3129.
+        We need to overwrite __getattr__ here in addition so that we don't trigger
+        `torch.nn.Module`'s __getattr__:
+        https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
+        """
+
+        is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
+        is_attribute = name in self.__dict__
+
+        if is_in_config and not is_attribute:
+            deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
+            deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
+            return self._internal_dict[name]
+
+        # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
+        return super().__getattr__(name)
+
     @property
     def is_gradient_checkpointing(self) -> bool:
         """
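`ModelMixin` needs its own override because `torch.nn.Module` already defines `__getattr__` to resolve submodules, parameters, and buffers from `_modules`, `_parameters`, and `_buffers`; the new method must fall through to `super().__getattr__(name)` or ordinary model attribute access would break. This blanket fallback is also what makes the hand-written `@property` deprecation shims removed throughout this diff redundant. A minimal sketch (hypothetical class) of the interaction:

```python
# Minimal sketch of why ModelMixin must fall through to
# torch.nn.Module.__getattr__: nn.Module stores submodules outside instance
# __dict__ and resolves them in its own __getattr__. `TinyModel` is hypothetical.
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self._internal_dict = {"in_channels": 4}  # stand-in for the config
        self.conv = nn.Conv2d(4, 4, 3)            # stored in self._modules

    def __getattr__(self, name):
        d = self.__dict__.get("_internal_dict", {})
        if name in d and name not in self.__dict__:
            # deprecation warning would go here
            return d[name]
        # Without this fallback, `self.conv` would raise AttributeError,
        # because nn.Module keeps it in _modules, not in __dict__.
        return super().__getattr__(name)


m = TinyModel()
print(m.in_channels)  # 4, via the config fallback
print(type(m.conv))   # Conv2d, via nn.Module.__getattr__
```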
@@ -19,7 +19,7 @@ import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, deprecate
+from ..utils import BaseOutput
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
@@ -190,16 +190,6 @@ class UNet1DModel(ModelMixin, ConfigMixin):
             fc_dim=block_out_channels[-1] // 4,
         )

-    @property
-    def in_channels(self):
-        deprecate(
-            "in_channels",
-            "1.0.0",
-            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
-            standard_warn=False,
-        )
-        return self.config.in_channels
-
     def forward(
         self,
         sample: torch.FloatTensor,
@@ -18,7 +18,7 @@ import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, deprecate
+from ..utils import BaseOutput
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
@@ -216,16 +216,6 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

-    @property
-    def in_channels(self):
-        deprecate(
-            "in_channels",
-            "1.0.0",
-            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
-            standard_warn=False,
-        )
-        return self.config.in_channels
-
     def forward(
         self,
         sample: torch.FloatTensor,
@@ -21,7 +21,7 @@ import torch.utils.checkpoint

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
-from ..utils import BaseOutput, deprecate, logging
+from ..utils import BaseOutput, logging
 from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
@@ -447,16 +447,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         )

-    @property
-    def in_channels(self):
-        deprecate(
-            "in_channels",
-            "1.0.0",
-            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
-            standard_warn=False,
-        )
-        return self.config.in_channels
-
     @property
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
         r"""
@@ -507,7 +507,7 @@ class DiffusionPipeline(ConfigMixin):
             setattr(self, name, module)

     def __setattr__(self, name: str, value: Any):
-        if hasattr(self, name) and hasattr(self.config, name):
+        if name in self.__dict__ and hasattr(self.config, name):
             # We need to overwrite the config if name exists in config
             if isinstance(getattr(self.config, name), (tuple, list)):
                 if value is not None and self.config[name][0] is not None:
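This one-line change is a direct consequence of the new `__getattr__`: `hasattr(self, name)` now succeeds for every config key (and would take the deprecated fallback path during `__setattr__`), whereas `name in self.__dict__` matches only real instance attributes. A small sketch of the difference (hypothetical class; the config key is chosen for illustration):

```python
# Why `hasattr(self, name)` had to become `name in self.__dict__`:
# once __getattr__ falls back to the config, hasattr() succeeds for
# every config key, even when no real instance attribute exists.
class Demo:
    def __init__(self):
        self._internal_dict = {"requires_safety_checker": True}

    def __getattr__(self, name):
        if name in self.__dict__.get("_internal_dict", {}):
            return self._internal_dict[name]
        raise AttributeError(name)


d = Demo()
print(hasattr(d, "requires_safety_checker"))    # True  -- via __getattr__
print("requires_safety_checker" in d.__dict__)  # False -- not a real attribute
```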
@@ -635,26 +635,25 @@ class DiffusionPipeline(ConfigMixin):
             )

         module_names, _ = self._get_signature_keys(self)
-        module_names = [m for m in module_names if hasattr(self, m)]
+        modules = [getattr(self, n, None) for n in module_names]
+        modules = [m for m in modules if isinstance(m, torch.nn.Module)]

         is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
-        for name in module_names:
-            module = getattr(self, name)
-            if isinstance(module, torch.nn.Module):
-                module.to(torch_device, torch_dtype)
-                if (
-                    module.dtype == torch.float16
-                    and str(torch_device) in ["cpu"]
-                    and not silence_dtype_warnings
-                    and not is_offloaded
-                ):
-                    logger.warning(
-                        "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
-                        " is not recommended to move them to `cpu` as running them will fail. Please make"
-                        " sure to use an accelerator to run the pipeline in inference, due to the lack of"
-                        " support for`float16` operations on this device in PyTorch. Please, remove the"
-                        " `torch_dtype=torch.float16` argument, or use another device for inference."
-                    )
+        for module in modules:
+            module.to(torch_device, torch_dtype)
+            if (
+                module.dtype == torch.float16
+                and str(torch_device) in ["cpu"]
+                and not silence_dtype_warnings
+                and not is_offloaded
+            ):
+                logger.warning(
+                    "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
+                    " is not recommended to move them to `cpu` as running them will fail. Please make"
+                    " sure to use an accelerator to run the pipeline in inference, due to the lack of"
+                    " support for`float16` operations on this device in PyTorch. Please, remove the"
+                    " `torch_dtype=torch.float16` argument, or use another device for inference."
+                )
         return self

     @property
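The `to()` rewrite, and the three sibling hunks below, standardize on one idiom: list the pipeline's registered components once, keep only the `torch.nn.Module` instances, then iterate. Roughly, as a helper (a sketch; `_get_signature_keys` is the private helper the diff itself uses):

```python
import torch


def pipeline_torch_modules(pipe):
    """Collect a pipeline's nn.Module components (a sketch of the new idiom).

    getattr(..., None) tolerates components that were never set (e.g. an
    optional safety checker); the isinstance filter drops tokenizers,
    feature extractors, and Nones in a single pass.
    """
    module_names, _ = pipe._get_signature_keys(pipe)
    modules = [getattr(pipe, n, None) for n in module_names]
    return [m for m in modules if isinstance(m, torch.nn.Module)]
```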
@@ -664,12 +663,12 @@ class DiffusionPipeline(ConfigMixin):
         `torch.device`: The torch device on which the pipeline is located.
         """
         module_names, _ = self._get_signature_keys(self)
-        module_names = [m for m in module_names if hasattr(self, m)]
+        modules = [getattr(self, n, None) for n in module_names]
+        modules = [m for m in modules if isinstance(m, torch.nn.Module)]

-        for name in module_names:
-            module = getattr(self, name)
-            if isinstance(module, torch.nn.Module):
-                return module.device
+        for module in modules:
+            return module.device
+
         return torch.device("cpu")

     @classmethod
@@ -1438,13 +1437,12 @@ class DiffusionPipeline(ConfigMixin):
             for child in module.children():
                 fn_recursive_set_mem_eff(child)

-        module_names, _, _ = self.extract_init_dict(dict(self.config))
-        module_names = [m for m in module_names if hasattr(self, m)]
+        module_names, _ = self._get_signature_keys(self)
+        modules = [getattr(self, n, None) for n in module_names]
+        modules = [m for m in modules if isinstance(m, torch.nn.Module)]

-        for module_name in module_names:
-            module = getattr(self, module_name)
-            if isinstance(module, torch.nn.Module):
-                fn_recursive_set_mem_eff(module)
+        for module in modules:
+            fn_recursive_set_mem_eff(module)

     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
@@ -1471,10 +1469,9 @@ class DiffusionPipeline(ConfigMixin):
         self.enable_attention_slicing(None)

     def set_attention_slice(self, slice_size: Optional[int]):
-        module_names, _, _ = self.extract_init_dict(dict(self.config))
-        module_names = [m for m in module_names if hasattr(self, m)]
+        module_names, _ = self._get_signature_keys(self)
+        modules = [getattr(self, n, None) for n in module_names]
+        modules = [m for m in modules if isinstance(m, torch.nn.Module) and hasattr(m, "set_attention_slice")]

-        for module_name in module_names:
-            module = getattr(self, module_name)
-            if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"):
-                module.set_attention_slice(slice_size)
+        for module in modules:
+            module.set_attention_slice(slice_size)
@@ -441,7 +441,7 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
         timesteps = self.scheduler.timesteps

         # Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             num_channels_latents,
@@ -413,9 +413,9 @@ class UnCLIPPipeline(DiffusionPipeline):
         self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
         decoder_timesteps_tensor = self.decoder_scheduler.timesteps

-        num_channels_latents = self.decoder.in_channels
-        height = self.decoder.sample_size
-        width = self.decoder.sample_size
+        num_channels_latents = self.decoder.config.in_channels
+        height = self.decoder.config.sample_size
+        width = self.decoder.config.sample_size

         decoder_latents = self.prepare_latents(
             (batch_size, num_channels_latents, height, width),
@@ -466,9 +466,9 @@ class UnCLIPPipeline(DiffusionPipeline):
         self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
         super_res_timesteps_tensor = self.super_res_scheduler.timesteps

-        channels = self.super_res_first.in_channels // 2
-        height = self.super_res_first.sample_size
-        width = self.super_res_first.sample_size
+        channels = self.super_res_first.config.in_channels // 2
+        height = self.super_res_first.config.sample_size
+        width = self.super_res_first.config.sample_size

         super_res_latents = self.prepare_latents(
             (batch_size, channels, height, width),
@@ -339,9 +339,9 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
         self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
         decoder_timesteps_tensor = self.decoder_scheduler.timesteps

-        num_channels_latents = self.decoder.in_channels
-        height = self.decoder.sample_size
-        width = self.decoder.sample_size
+        num_channels_latents = self.decoder.config.in_channels
+        height = self.decoder.config.sample_size
+        width = self.decoder.config.sample_size

         if decoder_latents is None:
             decoder_latents = self.prepare_latents(
@@ -393,9 +393,9 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
         self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
         super_res_timesteps_tensor = self.super_res_scheduler.timesteps

-        channels = self.super_res_first.in_channels // 2
-        height = self.super_res_first.sample_size
-        width = self.super_res_first.sample_size
+        channels = self.super_res_first.config.in_channels // 2
+        height = self.super_res_first.config.sample_size
+        width = self.super_res_first.config.sample_size

         if super_res_latents is None:
             super_res_latents = self.prepare_latents(
@@ -18,7 +18,7 @@ from ...models.dual_transformer_2d import DualTransformer2DModel
 from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from ...models.transformer_2d import Transformer2DModel
 from ...models.unet_2d_condition import UNet2DConditionOutput
-from ...utils import deprecate, logging
+from ...utils import logging


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -544,19 +544,6 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         )

-    @property
-    def in_channels(self):
-        deprecate(
-            "in_channels",
-            "1.0.0",
-            (
-                "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use"
-                " `unet.config.in_channels` instead"
-            ),
-            standard_warn=False,
-        )
-        return self.config.in_channels
-
     @property
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
         r"""
@@ -533,7 +533,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.image_unet.in_channels
+        num_channels_latents = self.image_unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
@@ -378,7 +378,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.image_unet.in_channels
+        num_channels_latents = self.image_unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
@@ -452,7 +452,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.image_unet.in_channels
+        num_channels_latents = self.image_unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
@@ -22,7 +22,7 @@ import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, deprecate, randn_tensor
+from ..utils import BaseOutput, randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
@@ -167,16 +167,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):

         self.variance_type = variance_type

-    @property
-    def num_train_timesteps(self):
-        deprecate(
-            "num_train_timesteps",
-            "1.0.0",
-            "Accessing `num_train_timesteps` directly via scheduler.num_train_timesteps is deprecated. Please use `scheduler.config.num_train_timesteps instead`",
-            standard_warn=False,
-        )
-        return self.config.num_train_timesteps
-
     def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
@@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Union
 from packaging import version


-def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True):
+def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
     from .. import __version__

     deprecated_kwargs = take_from
@@ -32,7 +32,7 @@ def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn

     if warning is not None:
         warning = warning + " " if standard_warn else ""
-        warnings.warn(warning + message, FutureWarning, stacklevel=2)
+        warnings.warn(warning + message, FutureWarning, stacklevel=stacklevel)

     if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0:
         call_frame = inspect.getouterframes(inspect.currentframe())[1]
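The new `stacklevel` parameter exists so callers like `ModelMixin.__getattr__` (which passes `stacklevel=3`) can attribute the `FutureWarning` to the user's attribute access instead of the extra frame the indirection adds. A minimal standalone sketch of the mechanism (names hypothetical):

```python
# Minimal sketch of what a configurable stacklevel buys. With an extra
# __getattr__ frame between the user's code and warnings.warn, stacklevel=2
# would blame the library; stacklevel=3 points at the caller's line.
import warnings


def _warn(message, stacklevel=2):
    warnings.warn(message, FutureWarning, stacklevel=stacklevel)


class WithFallback:
    def __getattr__(self, name):
        # one extra frame here, hence stacklevel=3 (caller -> __getattr__ -> _warn)
        _warn(f"`{name}` is deprecated", stacklevel=3)
        return 42


obj = WithFallback()
_ = obj.old_attribute  # the warning is reported on this line
```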
@@ -293,16 +293,16 @@ class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         prior_latents = pipe.prepare_latents(
             shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
         )
-        shape = (batch_size, decoder.in_channels, decoder.sample_size, decoder.sample_size)
+        shape = (batch_size, decoder.config.in_channels, decoder.config.sample_size, decoder.config.sample_size)
         decoder_latents = pipe.prepare_latents(
             shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
         )

         shape = (
             batch_size,
-            super_res_first.in_channels // 2,
-            super_res_first.sample_size,
-            super_res_first.sample_size,
+            super_res_first.config.in_channels // 2,
+            super_res_first.config.sample_size,
+            super_res_first.config.sample_size,
         )
         super_res_latents = pipe.prepare_latents(
             shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
@@ -379,16 +379,21 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         dtype = pipe.decoder.dtype
         batch_size = 1

-        shape = (batch_size, pipe.decoder.in_channels, pipe.decoder.sample_size, pipe.decoder.sample_size)
+        shape = (
+            batch_size,
+            pipe.decoder.config.in_channels,
+            pipe.decoder.config.sample_size,
+            pipe.decoder.config.sample_size,
+        )
         decoder_latents = pipe.prepare_latents(
             shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
         )

         shape = (
             batch_size,
-            pipe.super_res_first.in_channels // 2,
-            pipe.super_res_first.sample_size,
-            pipe.super_res_first.sample_size,
+            pipe.super_res_first.config.in_channels // 2,
+            pipe.super_res_first.config.sample_size,
+            pipe.super_res_first.config.sample_size,
         )
         super_res_latents = pipe.prepare_latents(
             shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
@@ -596,3 +596,47 @@ class SchedulerCommonTest(unittest.TestCase):
             new_scheduler = scheduler_class.from_pretrained(tmpdirname)

         assert scheduler.betas.tolist() == new_scheduler.betas.tolist()
+
+    def test_getattr_is_correct(self):
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            # save some things to test
+            scheduler.dummy_attribute = 5
+            scheduler.register_to_config(test_attribute=5)
+
+            logger = logging.get_logger("diffusers.configuration_utils")
+            # 30 for warning
+            logger.setLevel(30)
+            with CaptureLogger(logger) as cap_logger:
+                assert hasattr(scheduler, "dummy_attribute")
+                assert getattr(scheduler, "dummy_attribute") == 5
+                assert scheduler.dummy_attribute == 5
+
+            # no warning should be thrown
+            assert cap_logger.out == ""
+
+            logger = logging.get_logger("diffusers.schedulers.scheduling_utils")
+            # 30 for warning
+            logger.setLevel(30)
+            with CaptureLogger(logger) as cap_logger:
+                assert hasattr(scheduler, "save_pretrained")
+                fn = scheduler.save_pretrained
+                fn_1 = getattr(scheduler, "save_pretrained")
+
+                assert fn == fn_1
+            # no warning should be thrown
+            assert cap_logger.out == ""
+
+            # warning should be thrown
+            with self.assertWarns(FutureWarning):
+                assert scheduler.test_attribute == 5
+
+            with self.assertWarns(FutureWarning):
+                assert getattr(scheduler, "test_attribute") == 5
+
+            with self.assertRaises(AttributeError) as error:
+                scheduler.does_not_exist
+
+            assert str(error.exception) == f"'{type(scheduler).__name__}' object has no attribute 'does_not_exist'"
@@ -26,8 +26,8 @@ from requests.exceptions import HTTPError

 from diffusers.models import UNet2DConditionModel
 from diffusers.training_utils import EMAModel
-from diffusers.utils import torch_device
-from diffusers.utils.testing_utils import require_torch_gpu
+from diffusers.utils import logging, torch_device
+from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu


 class ModelUtilsTest(unittest.TestCase):
@@ -155,6 +155,49 @@ class ModelTesterMixin:
             max_diff = (image - new_image).abs().sum().item()
             self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")

+    def test_getattr_is_correct(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        # save some things to test
+        model.dummy_attribute = 5
+        model.register_to_config(test_attribute=5)
+
+        logger = logging.get_logger("diffusers.models.modeling_utils")
+        # 30 for warning
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            assert hasattr(model, "dummy_attribute")
+            assert getattr(model, "dummy_attribute") == 5
+            assert model.dummy_attribute == 5
+
+        # no warning should be thrown
+        assert cap_logger.out == ""
+
+        logger = logging.get_logger("diffusers.models.modeling_utils")
+        # 30 for warning
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            assert hasattr(model, "save_pretrained")
+            fn = model.save_pretrained
+            fn_1 = getattr(model, "save_pretrained")
+
+            assert fn == fn_1
+        # no warning should be thrown
+        assert cap_logger.out == ""
+
+        # warning should be thrown
+        with self.assertWarns(FutureWarning):
+            assert model.test_attribute == 5
+
+        with self.assertWarns(FutureWarning):
+            assert getattr(model, "test_attribute") == 5
+
+        with self.assertRaises(AttributeError) as error:
+            model.does_not_exist
+
+        assert str(error.exception) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"
+
     def test_from_save_pretrained_variant(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()