mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Improve docs and type hints (#5759)

* improvement: docs and type hints

* improvement: docs and type hints

minor refactor

* improvement: docs and type hints

* update with suggestions from review

Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Aryan V S authored on 2023-11-16 14:00:32 +05:30, committed by GitHub
parent ecbe27a07f, commit 038b42db94
19 changed files with 533 additions and 415 deletions
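The most common change in this diff is replacing `Tuple[int]` and `Tuple[str]` annotations with `Tuple[int, ...]` and `Tuple[str, ...]`. A minimal, self-contained sketch of why (the variable names below are illustrative, not from the commit):

from typing import Tuple

# `Tuple[int]` means "a tuple containing exactly one int", while `Tuple[int, ...]`
# means "a tuple of ints of any length". The four-element default below only
# matches the second annotation, so static type checkers flag the first.
block_out_channels_narrow: Tuple[int] = (320, 640, 1280, 1280)         # rejected by type checkers
block_out_channels_variadic: Tuple[int, ...] = (320, 640, 1280, 1280)  # accepted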


@@ -13,7 +13,7 @@
# limitations under the License.
import warnings
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union
import numpy as np
import PIL.Image
@@ -126,14 +126,14 @@ class VaeImageProcessor(ConfigMixin):
return images
@staticmethod
def normalize(images):
def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
"""
Normalize an image array to [-1,1].
"""
return 2.0 * images - 1.0
@staticmethod
def denormalize(images):
def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
"""
Denormalize an image array to [0,1].
"""
@@ -159,10 +159,10 @@ class VaeImageProcessor(ConfigMixin):
def get_default_height_width(
self,
image: [PIL.Image.Image, np.ndarray, torch.Tensor],
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
height: Optional[int] = None,
width: Optional[int] = None,
):
) -> Tuple[int, int]:
"""
This function return the height and width that are downscaled to the next integer multiple of
`vae_scale_factor`.
@@ -202,12 +202,24 @@ class VaeImageProcessor(ConfigMixin):
def resize(
self,
image: [PIL.Image.Image, np.ndarray, torch.Tensor],
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
height: Optional[int] = None,
width: Optional[int] = None,
) -> [PIL.Image.Image, np.ndarray, torch.Tensor]:
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
"""
Resize image.
Args:
image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
The image input, can be a PIL image, numpy array or pytorch tensor.
height (`int`, *optional*, defaults to `None`):
The height to resize to.
width (`int`, *optional*, defaults to `None`):
The width to resize to.
Returns:
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
The resized image.
"""
if isinstance(image, PIL.Image.Image):
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
@@ -227,7 +239,15 @@ class VaeImageProcessor(ConfigMixin):
def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
"""
create a mask
Create a mask.
Args:
image (`PIL.Image.Image`):
The image input, should be a PIL image.
Returns:
`PIL.Image.Image`:
The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1.
"""
image[image < 0.5] = 0
image[image >= 0.5] = 1
@@ -327,7 +347,23 @@ class VaeImageProcessor(ConfigMixin):
image: torch.FloatTensor,
output_type: str = "pil",
do_denormalize: Optional[List[bool]] = None,
):
) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
"""
Postprocess the image output from tensor to `output_type`.
Args:
image (`torch.FloatTensor`):
The image input, should be a pytorch tensor with shape `B x C x H x W`.
output_type (`str`, *optional*, defaults to `pil`):
The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
do_denormalize (`List[bool]`, *optional*, defaults to `None`):
Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
`VaeImageProcessor` config.
Returns:
`PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
The postprocessed image.
"""
if not isinstance(image, torch.Tensor):
raise ValueError(
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
@@ -390,7 +426,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
super().__init__()
@staticmethod
def numpy_to_pil(images):
def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
"""
Convert a NumPy image or a batch of images to a PIL image.
"""
@@ -406,7 +442,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
return pil_images
@staticmethod
def rgblike_to_depthmap(image):
def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
"""
Args:
image: RGB-like depth image
@@ -416,7 +452,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
"""
return image[:, :, 1] * 2**8 + image[:, :, 2]
def numpy_to_depth(self, images):
def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
"""
Convert a NumPy depth image or a batch of images to a PIL image.
"""
@@ -441,7 +477,23 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
image: torch.FloatTensor,
output_type: str = "pil",
do_denormalize: Optional[List[bool]] = None,
):
) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
"""
Postprocess the image output from tensor to `output_type`.
Args:
image (`torch.FloatTensor`):
The image input, should be a pytorch tensor with shape `B x C x H x W`.
output_type (`str`, *optional*, defaults to `pil`):
The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
do_denormalize (`List[bool]`, *optional*, defaults to `None`):
Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
`VaeImageProcessor` config.
Returns:
`PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
The postprocessed image.
"""
if not isinstance(image, torch.Tensor):
raise ValueError(
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"

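As a quick illustration of the `postprocess` signature documented in the hunks above, here is a minimal usage sketch; the processor arguments and the random tensor are placeholders, not part of the diff:

import torch
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)

# Placeholder "decoded" sample in [-1, 1] with shape B x C x H x W.
sample = torch.rand(2, 3, 64, 64) * 2 - 1

# With the default `do_normalize=True` config, postprocess() denormalizes back to
# [0, 1] and converts to the requested output type.
pil_images = processor.postprocess(sample, output_type="pil")  # list of PIL.Image.Image
np_batch = processor.postprocess(sample, output_type="np")     # np.ndarray, B x H x W x C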

@@ -65,11 +65,11 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
down_block_out_channels: Tuple[int] = (64,),
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
down_block_out_channels: Tuple[int, ...] = (64,),
layers_per_down_block: int = 1,
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
up_block_out_channels: Tuple[int] = (64,),
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
up_block_out_channels: Tuple[int, ...] = (64,),
layers_per_up_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
@@ -109,7 +109,9 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
self.use_tiling = False
@apply_forward_hook
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
def encode(
self, x: torch.FloatTensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[torch.FloatTensor]]:
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
@@ -125,7 +127,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
image: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None,
return_dict: bool = True,
) -> Union[DecoderOutput, torch.FloatTensor]:
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
z = self.post_quant_conv(z)
dec = self.decoder(z, image, mask)
@@ -142,7 +144,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
image: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None,
return_dict: bool = True,
) -> Union[DecoderOutput, torch.FloatTensor]:
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
decoded = self._decode(z, image, mask).sample
if not return_dict:
@@ -157,7 +159,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.FloatTensor]:
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.

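A small sketch of the `return_dict` convention that the new `Union[..., Tuple[...]]` return annotations describe; the model below just uses the constructor defaults shown above, and the input shape is an arbitrary placeholder:

import torch
from diffusers import AsymmetricAutoencoderKL

vae = AsymmetricAutoencoderKL()          # toy default config: single down/up block, 64 channels
x = torch.randn(1, 3, 32, 32)            # placeholder image batch

out = vae.encode(x)                      # AutoencoderKLOutput
latents = out.latent_dist.sample()       # sample from the posterior

(posterior,) = vae.encode(x, return_dict=False)  # return_dict=False yields a plain tuple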

@@ -322,13 +322,13 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
return DecoderOutput(sample=decoded)
def blend_v(self, a, b, blend_extent):
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b
def blend_h(self, a, b, blend_extent):
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)

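The newly typed `blend_v` / `blend_h` helpers cross-fade the overlapping rows/columns of adjacent tiles. A standalone sketch of the horizontal blend on toy tensors (same loop as above, just outside the model class):

import torch

def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Linearly fade the last `blend_extent` columns of `a` into the first columns of `b`.
    blend_extent = min(a.shape[3], b.shape[3], blend_extent)
    for x in range(blend_extent):
        b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
    return b

left_tile = torch.zeros(1, 1, 2, 8)   # placeholder left tile
right_tile = torch.ones(1, 1, 2, 8)   # placeholder right tile
blended = blend_h(left_tile, right_tile, blend_extent=4)
print(blended[0, 0, 0])               # first four columns ramp 0.00, 0.25, 0.50, 0.75; the rest stay 1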

@@ -96,18 +96,18 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
in_channels=3,
out_channels=3,
encoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
decoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
in_channels: int = 3,
out_channels: int = 3,
encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
act_fn: str = "relu",
latent_channels: int = 4,
upsampling_scaling_factor: int = 2,
num_encoder_blocks: Tuple[int] = (1, 3, 3, 3),
num_decoder_blocks: Tuple[int] = (3, 3, 3, 1),
num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1),
latent_magnitude: int = 3,
latent_shift: float = 0.5,
force_upcast: float = False,
force_upcast: bool = False,
scaling_factor: float = 1.0,
):
super().__init__()
@@ -147,33 +147,33 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
self.tile_sample_min_size = 512
self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (EncoderTiny, DecoderTiny)):
module.gradient_checkpointing = value
def scale_latents(self, x):
def scale_latents(self, x: torch.FloatTensor) -> torch.FloatTensor:
"""raw latents -> [0, 1]"""
return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)
def unscale_latents(self, x):
def unscale_latents(self, x: torch.FloatTensor) -> torch.FloatTensor:
"""[0, 1] -> raw latents"""
return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
def enable_slicing(self):
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def enable_tiling(self, use_tiling: bool = True):
def enable_tiling(self, use_tiling: bool = True) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
@@ -181,7 +181,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
"""
self.use_tiling = use_tiling
def disable_tiling(self):
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
@@ -197,13 +197,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
Args:
x (`torch.FloatTensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.
Returns:
[`~models.autoencoder_tiny.AutoencoderTinyOutput`] or `tuple`:
If return_dict is True, a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] is returned, otherwise a
plain `tuple` is returned.
`torch.FloatTensor`: Encoded batch of images.
"""
# scale of encoder output relative to input
sf = self.spatial_scale_factor
@@ -249,13 +245,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
Args:
x (`torch.FloatTensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
`torch.FloatTensor`: Decoded batch of images.
"""
# scale of decoder output relative to input
sf = self.spatial_scale_factor

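A self-contained sketch of the `scale_latents` / `unscale_latents` mapping documented above, using the class defaults `latent_magnitude = 3` and `latent_shift = 0.5` (the free functions here simply mirror the methods):

import torch

latent_magnitude = 3
latent_shift = 0.5

def scale_latents(x: torch.Tensor) -> torch.Tensor:
    # raw latents -> [0, 1]
    return x.div(2 * latent_magnitude).add(latent_shift).clamp(0, 1)

def unscale_latents(x: torch.Tensor) -> torch.Tensor:
    # [0, 1] -> raw latents
    return x.sub(latent_shift).mul(2 * latent_magnitude)

raw = torch.linspace(-3, 3, steps=7)
assert torch.allclose(unscale_latents(scale_latents(raw)), raw)  # round trip inside [-3, 3]; values outside are clamped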

@@ -70,39 +70,39 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
scaling_factor=0.18215,
latent_channels=4,
encoder_act_fn="silu",
encoder_block_out_channels=(128, 256, 512, 512),
encoder_double_z=True,
encoder_down_block_types=(
scaling_factor: float = 0.18215,
latent_channels: int = 4,
encoder_act_fn: str = "silu",
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
encoder_double_z: bool = True,
encoder_down_block_types: Tuple[str, ...] = (
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
),
encoder_in_channels=3,
encoder_layers_per_block=2,
encoder_norm_num_groups=32,
encoder_out_channels=4,
decoder_add_attention=False,
decoder_block_out_channels=(320, 640, 1024, 1024),
decoder_down_block_types=(
encoder_in_channels: int = 3,
encoder_layers_per_block: int = 2,
encoder_norm_num_groups: int = 32,
encoder_out_channels: int = 4,
decoder_add_attention: bool = False,
decoder_block_out_channels: Tuple[int, ...] = (320, 640, 1024, 1024),
decoder_down_block_types: Tuple[str, ...] = (
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
),
decoder_downsample_padding=1,
decoder_in_channels=7,
decoder_layers_per_block=3,
decoder_norm_eps=1e-05,
decoder_norm_num_groups=32,
decoder_num_train_timesteps=1024,
decoder_out_channels=6,
decoder_resnet_time_scale_shift="scale_shift",
decoder_time_embedding_type="learned",
decoder_up_block_types=(
decoder_downsample_padding: int = 1,
decoder_in_channels: int = 7,
decoder_layers_per_block: int = 3,
decoder_norm_eps: float = 1e-05,
decoder_norm_num_groups: int = 32,
decoder_num_train_timesteps: int = 1024,
decoder_out_channels: int = 6,
decoder_resnet_time_scale_shift: str = "scale_shift",
decoder_time_embedding_type: str = "learned",
decoder_up_block_types: Tuple[str, ...] = (
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
@@ -304,8 +304,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
z: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
num_inference_steps=2,
) -> Union[DecoderOutput, torch.FloatTensor]:
num_inference_steps: int = 2,
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
z = (z * self.config.scaling_factor - self.means) / self.stds
scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
@@ -333,14 +333,14 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
return DecoderOutput(sample=x_0)
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_v
def blend_v(self, a, b, blend_extent):
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_h
def blend_h(self, a, b, blend_extent):
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
@@ -407,7 +407,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.FloatTensor]:
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
@@ -415,6 +415,12 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
generator (`torch.Generator`, *optional*, defaults to `None`):
Generator to use for sampling.
Returns:
[`DecoderOutput`] or `tuple`:
If return_dict is True, a [`DecoderOutput`] is returned, otherwise a plain `tuple` is returned.
"""
x = sample
posterior = self.encode(x).latent_dist


@@ -76,7 +76,7 @@ class ControlNetConditioningEmbedding(nn.Module):
self,
conditioning_embedding_channels: int,
conditioning_channels: int = 3,
block_out_channels: Tuple[int] = (16, 32, 96, 256),
block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
):
super().__init__()
@@ -171,6 +171,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
The tuple of output channel for each block in the `conditioning_embedding` layer.
global_pool_conditions (`bool`, defaults to `False`):
TODO(Patrick) - unused parameter.
addition_embed_type_num_heads (`int`, defaults to 64):
The number of heads to use for the `TextTimeEmbedding` layer.
"""
_supports_gradient_checkpointing = True
@@ -182,14 +185,14 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
conditioning_channels: int = 3,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
@@ -197,11 +200,11 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
@@ -211,9 +214,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
resnet_time_scale_shift: str = "default",
projection_class_embeddings_input_dim: Optional[int] = None,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
addition_embed_type_num_heads=64,
addition_embed_type_num_heads: int = 64,
):
super().__init__()
@@ -426,7 +429,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
cls,
unet: UNet2DConditionModel,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
load_weights_from_unet: bool = True,
):
r"""
@@ -570,7 +573,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
self.set_attn_processor(processor, _remove_lora=True)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size):
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
r"""
Enable sliced attention computation.
@@ -635,7 +638,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
module.gradient_checkpointing = value
@@ -653,7 +656,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple]:
) -> Union[ControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
"""
The [`ControlNetModel`] forward method.

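As an illustration of the now fully typed `ControlNetConditioningEmbedding` constructor above, a small standalone sketch; the 64x64 conditioning image is an arbitrary placeholder and the import path is assumed to be `diffusers.models.controlnet`:

import torch
from diffusers.models.controlnet import ControlNetConditioningEmbedding  # assumed module path

cond_embed = ControlNetConditioningEmbedding(
    conditioning_embedding_channels=320,
    conditioning_channels=3,
    block_out_channels=(16, 32, 96, 256),
)
cond = torch.randn(1, 3, 64, 64)   # placeholder conditioning image (e.g. canny edges)
embedding = cond_embed(cond)       # spatially downsampled by 8 -> (1, 320, 8, 8)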

@@ -46,10 +46,10 @@ class FlaxControlNetOutput(BaseOutput):
class FlaxControlNetConditioningEmbedding(nn.Module):
conditioning_embedding_channels: int
block_out_channels: Tuple[int] = (16, 32, 96, 256)
block_out_channels: Tuple[int, ...] = (16, 32, 96, 256)
dtype: jnp.dtype = jnp.float32
def setup(self):
def setup(self) -> None:
self.conv_in = nn.Conv(
self.block_out_channels[0],
kernel_size=(3, 3),
@@ -87,7 +87,7 @@ class FlaxControlNetConditioningEmbedding(nn.Module):
dtype=self.dtype,
)
def __call__(self, conditioning):
def __call__(self, conditioning: jnp.ndarray) -> jnp.ndarray:
embedding = self.conv_in(conditioning)
embedding = nn.silu(embedding)
@@ -148,17 +148,17 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
"""
sample_size: int = 32
in_channels: int = 4
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
)
only_cross_attention: Union[bool, Tuple[bool]] = False
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
only_cross_attention: Union[bool, Tuple[bool, ...]] = False
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
layers_per_block: int = 2
attention_head_dim: Union[int, Tuple[int]] = 8
num_attention_heads: Optional[Union[int, Tuple[int]]] = None
attention_head_dim: Union[int, Tuple[int, ...]] = 8
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
cross_attention_dim: int = 1280
dropout: float = 0.0
use_linear_projection: bool = False
@@ -166,7 +166,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
flip_sin_to_cos: bool = True
freq_shift: int = 0
controlnet_conditioning_channel_order: str = "rgb"
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)
conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256)
def init_weights(self, rng: jax.Array) -> FrozenDict:
# init input tensors
@@ -182,7 +182,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
def setup(self):
def setup(self) -> None:
block_out_channels = self.block_out_channels
time_embed_dim = block_out_channels[0] * 4
@@ -312,21 +312,21 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
def __call__(
self,
sample,
timesteps,
encoder_hidden_states,
controlnet_cond,
sample: jnp.ndarray,
timesteps: Union[jnp.ndarray, float, int],
encoder_hidden_states: jnp.ndarray,
controlnet_cond: jnp.ndarray,
conditioning_scale: float = 1.0,
return_dict: bool = True,
train: bool = False,
) -> Union[FlaxControlNetOutput, Tuple]:
) -> Union[FlaxControlNetOutput, Tuple[Tuple[jnp.ndarray, ...], jnp.ndarray]]:
r"""
Args:
sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
timestep (`jnp.ndarray` or `float` or `int`): timesteps
encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
conditioning_scale: (`float`) the scale factor for controlnet outputs
conditioning_scale (`float`, *optional*, defaults to `1.0`): the scale factor for controlnet outputs
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
plain tuple.
@@ -335,8 +335,8 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
Returns:
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is the sample tensor.
"""
channel_order = self.controlnet_conditioning_channel_order
if channel_order == "bgr":


@@ -18,13 +18,14 @@ import inspect
import itertools
import os
import re
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union
import safetensors
import torch
from huggingface_hub import create_repo
from torch import Tensor, device, nn
from torch import Tensor, nn
from .. import __version__
from ..utils import (
@@ -61,7 +62,7 @@ if is_accelerate_available():
from accelerate.utils.versions import is_torch_version
def get_parameter_device(parameter: torch.nn.Module):
def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
try:
parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
return next(parameters_and_buffers).device
@@ -77,7 +78,7 @@ def get_parameter_device(parameter: torch.nn.Module):
return first_tuple[1].device
def get_parameter_dtype(parameter: torch.nn.Module):
def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
try:
params = tuple(parameter.parameters())
if len(params) > 0:
@@ -130,7 +131,13 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
)
def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_name_or_path=None):
def load_model_dict_into_meta(
model,
state_dict: OrderedDict,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
model_name_or_path: Optional[str] = None,
) -> List[str]:
device = device or torch.device("cpu")
dtype = dtype or torch.float32
@@ -156,7 +163,7 @@ def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_
return unexpected_keys
def _load_state_dict_into_model(model_to_load, state_dict):
def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
# Convert old format to new format if needed from a PyTorch state_dict
# copy state_dict so _load_from_state_dict can modify it
state_dict = state_dict.copy()
@@ -164,7 +171,7 @@ def _load_state_dict_into_model(model_to_load, state_dict):
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module: torch.nn.Module, prefix=""):
def load(module: torch.nn.Module, prefix: str = ""):
args = (state_dict, prefix, {}, True, [], [], error_msgs)
module._load_from_state_dict(*args)
@@ -220,7 +227,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
"""
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
def enable_gradient_checkpointing(self):
def enable_gradient_checkpointing(self) -> None:
"""
Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
*checkpoint activations* in other frameworks).
@@ -229,7 +236,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
self.apply(partial(self._set_gradient_checkpointing, value=True))
def disable_gradient_checkpointing(self):
def disable_gradient_checkpointing(self) -> None:
"""
Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
*checkpoint activations* in other frameworks).
@@ -254,7 +261,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if isinstance(module, torch.nn.Module):
fn_recursive_set_mem_eff(module)
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None:
r"""
Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
@@ -290,7 +297,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
"""
self.set_use_memory_efficient_attention_xformers(True, attention_op)
def disable_xformers_memory_efficient_attention(self):
def disable_xformers_memory_efficient_attention(self) -> None:
r"""
Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
"""
@@ -447,7 +454,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
save_function: Callable = None,
save_function: Optional[Callable] = None,
safe_serialization: bool = True,
variant: Optional[str] = None,
push_to_hub: bool = False,
@@ -910,10 +917,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
def _load_pretrained_model(
cls,
model,
state_dict,
state_dict: OrderedDict,
resolved_archive_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=False,
pretrained_model_name_or_path: Union[str, os.PathLike],
ignore_mismatched_sizes: bool = False,
):
# Retrieve missing & unexpected_keys
model_state_dict = model.state_dict()
@@ -1011,7 +1018,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
@property
def device(self) -> device:
def device(self) -> torch.device:
"""
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
device).
@@ -1063,7 +1070,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else:
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
def _convert_deprecated_attention_blocks(self, state_dict):
def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
deprecated_attention_block_paths = []
def recursive_find_attn_block(name, module):
@@ -1107,7 +1114,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if f"{path}.proj_attn.bias" in state_dict:
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
def _temp_convert_self_to_deprecated_attention_blocks(self):
def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
deprecated_attention_block_modules = []
def recursive_find_attn_block(module):
@@ -1134,10 +1141,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
del module.to_v
del module.to_out
def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None:
deprecated_attention_block_modules = []
def recursive_find_attn_block(module):
def recursive_find_attn_block(module) -> None:
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
deprecated_attention_block_modules.append(module)

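A quick standalone sketch of what the newly annotated `get_parameter_device` helper does in the common case, mirroring the `itertools.chain` lookup above on a plain module:

import itertools

import torch
from torch import nn

def get_parameter_device(module: nn.Module) -> torch.device:
    # Device of the first parameter or buffer, as in the helper above (error handling omitted).
    parameters_and_buffers = itertools.chain(module.parameters(), module.buffers())
    return next(parameters_and_buffers).device

model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
print(get_parameter_device(model))   # cpu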

@@ -101,8 +101,8 @@ class AdaLayerNormSingle(nn.Module):
def forward(
self,
timestep: torch.Tensor,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
batch_size: int = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
batch_size: Optional[int] = None,
hidden_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# No modulation happening here.


@@ -164,7 +164,9 @@ class Upsample2D(nn.Module):
else:
self.Conv2d_0 = conv
def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, scale: float = 1.0):
def forward(
self, hidden_states: torch.FloatTensor, output_size: Optional[int] = None, scale: float = 1.0
) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
if self.use_conv_transpose:
@@ -256,7 +258,7 @@ class Downsample2D(nn.Module):
else:
self.conv = conv
def forward(self, hidden_states, scale: float = 1.0):
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
if self.use_conv and self.padding == 0:
@@ -280,7 +282,7 @@ class FirUpsample2D(nn.Module):
"""A 2D FIR upsampling layer with an optional convolution.
Parameters:
channels (`int`):
channels (`int`, optional):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
@@ -292,7 +294,7 @@ class FirUpsample2D(nn.Module):
def __init__(
self,
channels: int = None,
channels: Optional[int] = None,
out_channels: Optional[int] = None,
use_conv: bool = False,
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
@@ -307,12 +309,12 @@ class FirUpsample2D(nn.Module):
def _upsample_2d(
self,
hidden_states: torch.Tensor,
weight: Optional[torch.Tensor] = None,
hidden_states: torch.FloatTensor,
weight: Optional[torch.FloatTensor] = None,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.Tensor:
) -> torch.FloatTensor:
"""Fused `upsample_2d()` followed by `Conv2d()`.
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
@@ -320,17 +322,21 @@ class FirUpsample2D(nn.Module):
arbitrary order.
Args:
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
weight: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).
hidden_states (`torch.FloatTensor`):
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
weight (`torch.FloatTensor`, *optional*):
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
performed by `inChannels = x.shape[0] // numGroups`.
kernel (`torch.FloatTensor`, *optional*):
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
corresponds to nearest-neighbor upsampling.
factor (`int`, *optional*): Integer upsampling factor (default: 2).
gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0).
Returns:
output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
datatype as `hidden_states`.
output (`torch.FloatTensor`):
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
datatype as `hidden_states`.
"""
assert isinstance(factor, int) and factor >= 1
@@ -392,7 +398,7 @@ class FirUpsample2D(nn.Module):
return output
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
if self.use_conv:
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
@@ -418,7 +424,7 @@ class FirDownsample2D(nn.Module):
def __init__(
self,
channels: int = None,
channels: Optional[int] = None,
out_channels: Optional[int] = None,
use_conv: bool = False,
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
@@ -433,30 +439,35 @@ class FirDownsample2D(nn.Module):
def _downsample_2d(
self,
hidden_states: torch.Tensor,
weight: Optional[torch.Tensor] = None,
hidden_states: torch.FloatTensor,
weight: Optional[torch.FloatTensor] = None,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.Tensor:
) -> torch.FloatTensor:
"""Fused `Conv2d()` followed by `downsample_2d()`.
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
arbitrary order.
Args:
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
weight:
hidden_states (`torch.FloatTensor`):
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
weight (`torch.FloatTensor`, *optional*):
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
performed by `inChannels = x.shape[0] // numGroups`.
kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).
kernel (`torch.FloatTensor`, *optional*):
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
corresponds to average pooling.
factor (`int`, *optional*, default to `2`):
Integer downsampling factor.
gain (`float`, *optional*, default to `1.0`):
Scaling factor for signal magnitude.
Returns:
output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
same datatype as `x`.
output (`torch.FloatTensor`):
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
datatype as `x`.
"""
assert isinstance(factor, int) and factor >= 1
@@ -492,7 +503,7 @@ class FirDownsample2D(nn.Module):
return output
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
if self.use_conv:
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
@@ -682,7 +693,9 @@ class ResnetBlock2D(nn.Module):
in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
)
def forward(self, input_tensor, temb, scale: float = 1.0):
def forward(
self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, scale: float = 1.0
) -> torch.FloatTensor:
hidden_states = input_tensor
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
@@ -778,7 +791,7 @@ class Conv1dBlock(nn.Module):
out_channels (`int`): Number of output channels.
kernel_size (`int` or `tuple`): Size of the convolving kernel.
n_groups (`int`, default `8`): Number of groups to separate the channels into.
activation (`str`, defaults `mish`): Name of the activation function.
activation (`str`, defaults to `mish`): Name of the activation function.
"""
def __init__(
@@ -853,8 +866,8 @@ class ResidualTemporalBlock1D(nn.Module):
def upsample_2d(
hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
) -> torch.Tensor:
hidden_states: torch.FloatTensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
) -> torch.FloatTensor:
r"""Upsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
@@ -862,14 +875,19 @@ def upsample_2d(
a: multiple of the upsampling factor.
Args:
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).
hidden_states (`torch.FloatTensor`):
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
kernel (`torch.FloatTensor`, *optional*):
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
corresponds to nearest-neighbor upsampling.
factor (`int`, *optional*, default to `2`):
Integer upsampling factor.
gain (`float`, *optional*, default to `1.0`):
Scaling factor for signal magnitude (default: 1.0).
Returns:
output: Tensor of the shape `[N, C, H * factor, W * factor]`
output (`torch.FloatTensor`):
Tensor of the shape `[N, C, H * factor, W * factor]`
"""
assert isinstance(factor, int) and factor >= 1
if kernel is None:
@@ -892,8 +910,8 @@ def upsample_2d(
def downsample_2d(
hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
) -> torch.Tensor:
hidden_states: torch.FloatTensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
) -> torch.FloatTensor:
r"""Downsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
@@ -901,14 +919,19 @@ def downsample_2d(
shape is a multiple of the downsampling factor.
Args:
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).
hidden_states (`torch.FloatTensor`)
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
kernel (`torch.FloatTensor`, *optional*):
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
corresponds to average pooling.
factor (`int`, *optional*, default to `2`):
Integer downsampling factor.
gain (`float`, *optional*, default to `1.0`):
Scaling factor for signal magnitude.
Returns:
output: Tensor of the shape `[N, C, H // factor, W // factor]`
output (`torch.FloatTensor`):
Tensor of the shape `[N, C, H // factor, W // factor]`
"""
assert isinstance(factor, int) and factor >= 1

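Per the docstrings above, passing `kernel=None` resolves to `[1] * factor`, which the docs describe as nearest-neighbor upsampling for `upsample_2d` and average pooling for `downsample_2d`. A standalone sketch of those two default behaviours with plain torch ops (not calling the diffusers helpers):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)   # placeholder N x C x H x W batch

# Nearest-neighbor upsampling by factor 2, which the default kernel corresponds to.
up = F.interpolate(x, scale_factor=2, mode="nearest")   # (1, 3, 16, 16)

# 2x2 average pooling, i.e. downsampling by factor 2 with the default kernel.
down = F.avg_pool2d(x, kernel_size=2)                   # (1, 3, 4, 4)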

@@ -100,18 +100,18 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
sample_size: int = 32
in_channels: int = 4
out_channels: int = 4
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
)
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
only_cross_attention: Union[bool, Tuple[bool]] = False
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
layers_per_block: int = 2
attention_head_dim: Union[int, Tuple[int]] = 8
num_attention_heads: Optional[Union[int, Tuple[int]]] = None
attention_head_dim: Union[int, Tuple[int, ...]] = 8
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
cross_attention_dim: int = 1280
dropout: float = 0.0
use_linear_projection: bool = False
@@ -120,7 +120,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
freq_shift: int = 0
use_memory_efficient_attention: bool = False
split_head_dim: bool = False
transformer_layers_per_block: Union[int, Tuple[int]] = 1
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1
addition_embed_type: Optional[str] = None
addition_time_embed_dim: Optional[int] = None
addition_embed_type_num_heads: int = 64
@@ -158,7 +158,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
}
return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]
def setup(self):
def setup(self) -> None:
block_out_channels = self.block_out_channels
time_embed_dim = block_out_channels[0] * 4
@@ -320,15 +320,15 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
def __call__(
self,
sample,
timesteps,
encoder_hidden_states,
sample: jnp.ndarray,
timesteps: Union[jnp.ndarray, float, int],
encoder_hidden_states: jnp.ndarray,
added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
down_block_additional_residuals=None,
mid_block_additional_residual=None,
down_block_additional_residuals: Optional[Tuple[jnp.ndarray, ...]] = None,
mid_block_additional_residual: Optional[jnp.ndarray] = None,
return_dict: bool = True,
train: bool = False,
) -> Union[FlaxUNet2DConditionOutput, Tuple]:
) -> Union[FlaxUNet2DConditionOutput, Tuple[jnp.ndarray]]:
r"""
Args:
sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor


@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch import nn
@@ -26,26 +26,26 @@ from .transformer_temporal import TransformerTemporalModel
def get_down_block(
down_block_type,
num_layers,
in_channels,
out_channels,
temb_channels,
add_downsample,
resnet_eps,
resnet_act_fn,
num_attention_heads,
resnet_groups=None,
cross_attention_dim=None,
downsample_padding=None,
dual_cross_attention=False,
use_linear_projection=True,
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
temporal_num_attention_heads=8,
temporal_max_seq_length=32,
):
down_block_type: str,
num_layers: int,
in_channels: int,
out_channels: int,
temb_channels: int,
add_downsample: bool,
resnet_eps: float,
resnet_act_fn: str,
num_attention_heads: int,
resnet_groups: Optional[int] = None,
cross_attention_dim: Optional[int] = None,
downsample_padding: Optional[int] = None,
dual_cross_attention: bool = False,
use_linear_projection: bool = True,
only_cross_attention: bool = False,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
) -> Union["DownBlock3D", "CrossAttnDownBlock3D", "DownBlockMotion", "CrossAttnDownBlockMotion"]:
if down_block_type == "DownBlock3D":
return DownBlock3D(
num_layers=num_layers,
@@ -123,28 +123,28 @@ def get_down_block(
def get_up_block(
up_block_type,
num_layers,
in_channels,
out_channels,
prev_output_channel,
temb_channels,
add_upsample,
resnet_eps,
resnet_act_fn,
num_attention_heads,
resolution_idx=None,
resnet_groups=None,
cross_attention_dim=None,
dual_cross_attention=False,
use_linear_projection=True,
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
temporal_num_attention_heads=8,
temporal_cross_attention_dim=None,
temporal_max_seq_length=32,
):
up_block_type: str,
num_layers: int,
in_channels: int,
out_channels: int,
prev_output_channel: int,
temb_channels: int,
add_upsample: bool,
resnet_eps: float,
resnet_act_fn: str,
num_attention_heads: int,
resolution_idx: Optional[int] = None,
resnet_groups: Optional[int] = None,
cross_attention_dim: Optional[int] = None,
dual_cross_attention: bool = False,
use_linear_projection: bool = True,
only_cross_attention: bool = False,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
temporal_num_attention_heads: int = 8,
temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length: int = 32,
) -> Union["UpBlock3D", "CrossAttnUpBlock3D", "UpBlockMotion", "CrossAttnUpBlockMotion"]:
if up_block_type == "UpBlock3D":
return UpBlock3D(
num_layers=num_layers,
@@ -236,12 +236,12 @@ class UNetMidBlock3DCrossAttn(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
dual_cross_attention=False,
use_linear_projection=True,
upcast_attention=False,
num_attention_heads: int = 1,
output_scale_factor: float = 1.0,
cross_attention_dim: int = 1280,
dual_cross_attention: bool = False,
use_linear_projection: bool = True,
upcast_attention: bool = False,
):
super().__init__()
@@ -328,13 +328,13 @@ class UNetMidBlock3DCrossAttn(nn.Module):
def forward(
self,
hidden_states,
temb=None,
encoder_hidden_states=None,
attention_mask=None,
num_frames=1,
cross_attention_kwargs=None,
):
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
num_frames: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.FloatTensor:
hidden_states = self.resnets[0](hidden_states, temb)
hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
for attn, temp_attn, resnet, temp_conv in zip(
@@ -368,15 +368,15 @@ class CrossAttnDownBlock3D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
downsample_padding=1,
add_downsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
num_attention_heads: int = 1,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
downsample_padding: int = 1,
add_downsample: bool = True,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
):
super().__init__()
resnets = []
@@ -454,13 +454,13 @@ class CrossAttnDownBlock3D(nn.Module):
def forward(
self,
hidden_states,
temb=None,
encoder_hidden_states=None,
attention_mask=None,
num_frames=1,
cross_attention_kwargs=None,
):
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
num_frames: int = 1,
cross_attention_kwargs: Dict[str, Any] = None,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
# TODO(Patrick, William) - attention mask is not used
output_states = ()
@@ -503,9 +503,9 @@ class DownBlock3D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_downsample=True,
downsample_padding=1,
output_scale_factor: float = 1.0,
add_downsample: bool = True,
downsample_padding: int = 1,
):
super().__init__()
resnets = []
@@ -552,7 +552,9 @@ class DownBlock3D(nn.Module):
self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None, num_frames=1):
def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, num_frames: int = 1
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = ()
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
@@ -584,15 +586,15 @@ class CrossAttnUpBlock3D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
add_upsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
resolution_idx=None,
num_attention_heads: int = 1,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
resolution_idx: Optional[int] = None,
):
super().__init__()
resnets = []
@@ -667,15 +669,15 @@ class CrossAttnUpBlock3D(nn.Module):
def forward(
self,
hidden_states,
res_hidden_states_tuple,
temb=None,
encoder_hidden_states=None,
upsample_size=None,
attention_mask=None,
num_frames=1,
cross_attention_kwargs=None,
):
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
num_frames: int = 1,
cross_attention_kwargs: Dict[str, Any] = None,
) -> torch.FloatTensor:
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
@@ -738,9 +740,9 @@ class UpBlock3D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_upsample=True,
resolution_idx=None,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
resolution_idx: Optional[int] = None,
):
super().__init__()
resnets = []
@@ -784,7 +786,14 @@ class UpBlock3D(nn.Module):
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
num_frames: int = 1,
) -> torch.FloatTensor:
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
@@ -833,12 +842,12 @@ class DownBlockMotion(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_downsample=True,
downsample_padding=1,
temporal_num_attention_heads=1,
temporal_cross_attention_dim=None,
temporal_max_seq_length=32,
output_scale_factor: float = 1.0,
add_downsample: bool = True,
downsample_padding: int = 1,
temporal_num_attention_heads: int = 1,
temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length: int = 32,
):
super().__init__()
resnets = []
@@ -890,7 +899,13 @@ class DownBlockMotion(nn.Module):
self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None, scale: float = 1.0, num_frames=1):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
num_frames: int = 1,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = ()
blocks = zip(self.resnets, self.motion_modules)
@@ -944,19 +959,19 @@ class CrossAttnDownBlockMotion(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
downsample_padding=1,
add_downsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
attention_type="default",
temporal_cross_attention_dim=None,
temporal_num_attention_heads=8,
temporal_max_seq_length=32,
num_attention_heads: int = 1,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
downsample_padding: int = 1,
add_downsample: bool = True,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
attention_type: str = "default",
temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
):
super().__init__()
resnets = []
@@ -1043,14 +1058,14 @@ class CrossAttnDownBlockMotion(nn.Module):
def forward(
self,
hidden_states,
temb=None,
encoder_hidden_states=None,
attention_mask=None,
num_frames=1,
encoder_attention_mask=None,
cross_attention_kwargs=None,
additional_residuals=None,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
num_frames: int = 1,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
additional_residuals: Optional[torch.FloatTensor] = None,
):
output_states = ()
@@ -1121,7 +1136,7 @@ class CrossAttnUpBlockMotion(nn.Module):
out_channels: int,
prev_output_channel: int,
temb_channels: int,
resolution_idx: int = None,
resolution_idx: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
@@ -1130,18 +1145,18 @@ class CrossAttnUpBlockMotion(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
add_upsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
attention_type="default",
temporal_cross_attention_dim=None,
temporal_num_attention_heads=8,
temporal_max_seq_length=32,
num_attention_heads: int = 1,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
attention_type: str = "default",
temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
):
super().__init__()
resnets = []
@@ -1232,8 +1247,8 @@ class CrossAttnUpBlockMotion(nn.Module):
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
num_frames=1,
):
num_frames: int = 1,
) -> torch.FloatTensor:
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
is_freeu_enabled = (
getattr(self, "s1", None)
@@ -1317,7 +1332,7 @@ class UpBlockMotion(nn.Module):
prev_output_channel: int,
out_channels: int,
temb_channels: int,
resolution_idx: int = None,
resolution_idx: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -1325,12 +1340,12 @@ class UpBlockMotion(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_upsample=True,
temporal_norm_num_groups=32,
temporal_cross_attention_dim=None,
temporal_num_attention_heads=8,
temporal_max_seq_length=32,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
temporal_norm_num_groups: int = 32,
temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
):
super().__init__()
resnets = []
@@ -1381,8 +1396,14 @@ class UpBlockMotion(nn.Module):
self.resolution_idx = resolution_idx
def forward(
self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0, num_frames=1
):
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
scale: float = 1.0,
num_frames: int = 1,
) -> torch.FloatTensor:
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
@@ -1457,16 +1478,16 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
dual_cross_attention=False,
use_linear_projection=False,
upcast_attention=False,
attention_type="default",
temporal_num_attention_heads=1,
temporal_cross_attention_dim=None,
temporal_max_seq_length=32,
num_attention_heads: int = 1,
output_scale_factor: float = 1.0,
cross_attention_dim: int = 1280,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
upcast_attention: bool = False,
attention_type: str = "default",
temporal_num_attention_heads: int = 1,
temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length: int = 32,
):
super().__init__()
@@ -1560,7 +1581,7 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
num_frames=1,
num_frames: int = 1,
) -> torch.FloatTensor:
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
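To make the new annotations on the motion blocks concrete, here is a minimal sketch that calls the typed `DownBlockMotion.forward`. The import path and the small channel sizes are assumptions chosen only to keep the snippet light; it illustrates the signature rather than reproducing anything from the patch.

```python
import torch
from diffusers.models.unet_3d_blocks import DownBlockMotion  # assumed import path for this release

# Hypothetical, deliberately small config (channels stay divisible by the default 32 norm groups).
block = DownBlockMotion(in_channels=32, out_channels=32, temb_channels=128)

batch, num_frames = 1, 2
# Video latents are flattened to (batch * num_frames, channels, height, width) before entering the block.
hidden_states = torch.randn(batch * num_frames, 32, 32, 32)
temb = torch.randn(batch * num_frames, 128)  # per-frame time embedding

hidden_states, output_states = block(hidden_states, temb=temb, num_frames=num_frames)
print(hidden_states.shape)   # downsampled feature map
print(len(output_states))    # residuals collected for the skip connections
```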

View File

@@ -98,14 +98,19 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
sample_size: Optional[int] = None,
in_channels: int = 4,
out_channels: int = 4,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D",
),
up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
up_block_types: Tuple[str, ...] = (
"UpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
@@ -302,7 +307,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size):
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
r"""
Enable sliced attention computation.
@@ -404,7 +409,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def enable_forward_chunking(self, chunk_size=None, dim=0):
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
Sets the attention processor to use [feed forward
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
@@ -460,7 +465,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
self.set_attn_processor(processor, _remove_lora=True)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
module.gradient_checkpointing = value
@@ -510,7 +515,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet3DConditionOutput, Tuple]:
) -> Union[UNet3DConditionOutput, Tuple[torch.FloatTensor]]:
r"""
The [`UNet3DConditionModel`] forward method.
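A quick sketch of the two `UNet3DConditionModel` helpers whose signatures were annotated above, `set_attention_slice` and `enable_forward_chunking`. The tiny configuration is hypothetical, picked only so the model constructs quickly; the defaults build a much larger network.

```python
from diffusers import UNet3DConditionModel

# Hypothetical tiny config for illustration only.
unet = UNet3DConditionModel(
    block_out_channels=(32, 64, 64, 64),
    norm_num_groups=32,
    cross_attention_dim=32,
    attention_head_dim=8,
    layers_per_block=1,
)

unet.set_attention_slice("auto")                   # slice attention computation to save memory
unet.enable_forward_chunking(chunk_size=1, dim=0)  # chunk the feed-forward layers
unet.disable_forward_chunking()
```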

View File

@@ -50,14 +50,14 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class MotionModules(nn.Module):
def __init__(
self,
in_channels,
layers_per_block=2,
num_attention_heads=8,
attention_bias=False,
cross_attention_dim=None,
activation_fn="geglu",
norm_num_groups=32,
max_seq_length=32,
in_channels: int,
layers_per_block: int = 2,
num_attention_heads: int = 8,
attention_bias: bool = False,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
norm_num_groups: int = 32,
max_seq_length: int = 32,
):
super().__init__()
self.motion_modules = nn.ModuleList([])
@@ -82,13 +82,13 @@ class MotionAdapter(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
block_out_channels=(320, 640, 1280, 1280),
motion_layers_per_block=2,
motion_mid_block_layers_per_block=1,
motion_num_attention_heads=8,
motion_norm_num_groups=32,
motion_max_seq_length=32,
use_motion_mid_block=True,
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
motion_layers_per_block: int = 2,
motion_mid_block_layers_per_block: int = 1,
motion_num_attention_heads: int = 8,
motion_norm_num_groups: int = 32,
motion_max_seq_length: int = 32,
use_motion_mid_block: bool = True,
):
"""Container to store AnimateDiff Motion Modules
@@ -182,29 +182,29 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
sample_size: Optional[int] = None,
in_channels: int = 4,
out_channels: int = 4,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlockMotion",
"CrossAttnDownBlockMotion",
"CrossAttnDownBlockMotion",
"DownBlockMotion",
),
up_block_types: Tuple[str] = (
up_block_types: Tuple[str, ...] = (
"UpBlockMotion",
"CrossAttnUpBlockMotion",
"CrossAttnUpBlockMotion",
"CrossAttnUpBlockMotion",
),
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_num_groups: int = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
use_linear_projection: bool = False,
num_attention_heads: Optional[Union[int, Tuple[int]]] = 8,
motion_max_seq_length: Optional[int] = 32,
num_attention_heads: Union[int, Tuple[int, ...]] = 8,
motion_max_seq_length: int = 32,
motion_num_attention_heads: int = 8,
use_motion_mid_block: bool = True,
):
@@ -448,7 +448,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
return model
def freeze_unet2d_params(self):
def freeze_unet2d_params(self) -> None:
"""Freeze the weights of just the UNet2DConditionModel, and leave the motion modules
unfrozen for fine tuning.
"""
@@ -472,9 +472,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
for param in motion_modules.parameters():
param.requires_grad = True
return
def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]):
def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]) -> None:
for i, down_block in enumerate(motion_adapter.down_blocks):
self.down_blocks[i].motion_modules.load_state_dict(down_block.motion_modules.state_dict())
for i, up_block in enumerate(motion_adapter.up_blocks):
@@ -492,7 +490,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
variant: Optional[str] = None,
push_to_hub: bool = False,
**kwargs,
):
) -> None:
state_dict = self.state_dict()
# Extract all motion modules
@@ -582,7 +580,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size=None, dim=0):
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
Sets the attention processor to use [feed forward
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
@@ -612,7 +610,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
fn_recursive_feed_forward(module, chunk_size, dim)
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
def disable_forward_chunking(self):
def disable_forward_chunking(self) -> None:
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
@@ -624,7 +622,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
fn_recursive_feed_forward(module, None, 0)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
def set_default_attn_processor(self) -> None:
"""
Disables custom attention processors and sets the default attention implementation.
"""
@@ -639,12 +637,12 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
self.set_attn_processor(processor, _remove_lora=True)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
module.gradient_checkpointing = value
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1, s2, b1, b2):
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stage blocks where they are being applied.
@@ -669,7 +667,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
setattr(upsample_block, "b2", b2)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu
def disable_freeu(self):
def disable_freeu(self) -> None:
"""Disables the FreeU mechanism."""
freeu_keys = {"s1", "s2", "b1", "b2"}
for i, upsample_block in enumerate(self.up_blocks):
@@ -688,7 +686,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet3DConditionOutput, Tuple]:
) -> Union[UNet3DConditionOutput, Tuple[torch.Tensor]]:
r"""
The [`UNetMotionModel`] forward method.
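The methods annotated in this file are easiest to see in use. Below is a hedged sketch with a hypothetical small `UNetMotionModel` config (real checkpoints use the much larger defaults); the FreeU factors are illustrative values, not tuned recommendations.

```python
from diffusers import UNetMotionModel

# Hypothetical small config so the snippet stays cheap.
unet = UNetMotionModel(
    block_out_channels=(32, 64, 64, 64),
    norm_num_groups=32,
    cross_attention_dim=32,
    num_attention_heads=4,
    layers_per_block=1,
    motion_num_attention_heads=4,
)

# Keep only the motion modules trainable, as the docstring describes.
unet.freeze_unet2d_params()
trainable = sum(p.numel() for p in unet.parameters() if p.requires_grad)
total = sum(p.numel() for p in unet.parameters())
print(f"trainable motion parameters: {trainable} / {total}")

# Illustrative FreeU scaling factors.
unet.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
unet.disable_freeu()
```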

View File

@@ -148,7 +148,9 @@ class VQModel(ModelMixin, ConfigMixin):
return DecoderOutput(sample=dec)
def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
def forward(
self, sample: torch.FloatTensor, return_dict: bool = True
) -> Union[DecoderOutput, Tuple[torch.FloatTensor, ...]]:
r"""
The [`VQModel`] forward method.
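The annotated `VQModel.forward` returns either a `DecoderOutput` or a plain tuple depending on `return_dict`. A minimal sketch, assuming the default config accepts 3-channel, 32x32 inputs:

```python
import torch
from diffusers import VQModel

model = VQModel()                                   # default (tiny) VQ-VAE config
sample = torch.randn(1, 3, 32, 32)

out = model(sample)                                 # DecoderOutput with a `.sample` attribute
recon = out.sample
recon_tuple = model(sample, return_dict=False)[0]   # tuple-style return, same tensor type
print(recon.shape, recon_tuple.shape)
```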

View File

@@ -37,7 +37,7 @@ class SchedulerType(Enum):
PIECEWISE_CONSTANT = "piecewise_constant"
def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1) -> LambdaLR:
"""
Create a schedule with a constant learning rate, using the learning rate set in optimizer.
@@ -53,7 +53,7 @@ def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1) -> LambdaLR:
"""
Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
increases linearly between 0 and the initial lr set in the optimizer.
@@ -78,7 +78,7 @@ def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: in
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1):
def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1) -> LambdaLR:
"""
Create a schedule with a constant learning rate, using the learning rate set in optimizer.
@@ -120,7 +120,9 @@ def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_
return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
def get_linear_schedule_with_warmup(
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, last_epoch: int = -1
) -> LambdaLR:
"""
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
@@ -151,7 +153,7 @@ def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_st
def get_cosine_schedule_with_warmup(
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
):
) -> LambdaLR:
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
@@ -185,7 +187,7 @@ def get_cosine_schedule_with_warmup(
def get_cosine_with_hard_restarts_schedule_with_warmup(
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
):
) -> LambdaLR:
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
@@ -219,8 +221,13 @@ def get_cosine_with_hard_restarts_schedule_with_warmup(
def get_polynomial_decay_schedule_with_warmup(
optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
):
optimizer: Optimizer,
num_warmup_steps: int,
num_training_steps: int,
lr_end: float = 1e-7,
power: float = 1.0,
last_epoch: int = -1,
) -> LambdaLR:
"""
Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
@@ -288,7 +295,7 @@ def get_scheduler(
num_cycles: int = 1,
power: float = 1.0,
last_epoch: int = -1,
):
) -> LambdaLR:
"""
Unified API to get any scheduler from its name.
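All of the scheduler factories above return a `torch.optim.lr_scheduler.LambdaLR`. A short usage sketch with a throwaway model and arbitrary step counts:

```python
import torch
from diffusers.optimization import get_cosine_schedule_with_warmup, get_scheduler

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Call one factory directly...
lr_scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=1000)

# ...or resolve it by name through the unified entry point.
lr_scheduler = get_scheduler("cosine", optimizer=optimizer, num_warmup_steps=100, num_training_steps=1000)

for _ in range(5):
    optimizer.step()       # no-op here (no gradients), shown only for call ordering
    lr_scheduler.step()
print(lr_scheduler.get_last_lr())
```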

View File

@@ -28,7 +28,7 @@ from logging import (
WARN, # NOQA
WARNING, # NOQA
)
from typing import Optional
from typing import Dict, Optional
from tqdm import auto as tqdm_lib
@@ -49,7 +49,7 @@ _default_log_level = logging.WARNING
_tqdm_active = True
def _get_default_logging_level():
def _get_default_logging_level() -> int:
"""
If DIFFUSERS_VERBOSITY env var is set to one of the valid choices, return that as the new default level. If it is
not, fall back to `_default_log_level`
@@ -104,7 +104,7 @@ def _reset_library_root_logger() -> None:
_default_handler = None
def get_log_levels_dict():
def get_log_levels_dict() -> Dict[str, int]:
return log_levels
@@ -161,22 +161,22 @@ def set_verbosity(verbosity: int) -> None:
_get_library_root_logger().setLevel(verbosity)
def set_verbosity_info():
def set_verbosity_info() -> None:
"""Set the verbosity to the `INFO` level."""
return set_verbosity(INFO)
def set_verbosity_warning():
def set_verbosity_warning() -> None:
"""Set the verbosity to the `WARNING` level."""
return set_verbosity(WARNING)
def set_verbosity_debug():
def set_verbosity_debug() -> None:
"""Set the verbosity to the `DEBUG` level."""
return set_verbosity(DEBUG)
def set_verbosity_error():
def set_verbosity_error() -> None:
"""Set the verbosity to the `ERROR` level."""
return set_verbosity(ERROR)
@@ -263,7 +263,7 @@ def reset_format() -> None:
handler.setFormatter(None)
def warning_advice(self, *args, **kwargs):
def warning_advice(self, *args, **kwargs) -> None:
"""
This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
warning will not be printed
@@ -327,13 +327,13 @@ def is_progress_bar_enabled() -> bool:
return bool(_tqdm_active)
def enable_progress_bar():
def enable_progress_bar() -> None:
"""Enable tqdm progress bar."""
global _tqdm_active
_tqdm_active = True
def disable_progress_bar():
def disable_progress_bar() -> None:
"""Disable tqdm progress bar."""
global _tqdm_active
_tqdm_active = False
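Typical usage of the now-typed logging helpers, assuming the standard `diffusers.utils.logging` entry point:

```python
from diffusers.utils import logging

logging.set_verbosity_info()            # or set_verbosity_warning / set_verbosity_debug / set_verbosity_error
logger = logging.get_logger(__name__)
logger.info("verbosity is now INFO")

logging.disable_progress_bar()          # silence tqdm bars, e.g. in CI logs
logging.enable_progress_bar()
```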

View File

@@ -24,7 +24,7 @@ import numpy as np
from .import_utils import is_torch_available
def is_tensor(x):
def is_tensor(x) -> bool:
"""
Tests if `x` is a `torch.Tensor` or `np.ndarray`.
"""
@@ -66,7 +66,7 @@ class BaseOutput(OrderedDict):
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
)
def __post_init__(self):
def __post_init__(self) -> None:
class_fields = fields(self)
# Safety and consistency checks
@@ -97,14 +97,14 @@ class BaseOutput(OrderedDict):
def update(self, *args, **kwargs):
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
def __getitem__(self, k):
def __getitem__(self, k: Any) -> Any:
if isinstance(k, str):
inner_dict = dict(self.items())
return inner_dict[k]
else:
return self.to_tuple()[k]
def __setattr__(self, name, value):
def __setattr__(self, name: Any, value: Any) -> None:
if name in self.keys() and value is not None:
# Don't call self.__setitem__ to avoid recursion errors
super().__setitem__(name, value)
@@ -123,7 +123,7 @@ class BaseOutput(OrderedDict):
args = tuple(getattr(self, field.name) for field in fields(self))
return callable, args, *remaining
def to_tuple(self) -> Tuple[Any]:
def to_tuple(self) -> Tuple[Any, ...]:
"""
Convert self to a tuple containing all the attributes/keys that are not `None`.
"""

View File

@@ -82,14 +82,14 @@ def randn_tensor(
return latents
def is_compiled_module(module):
def is_compiled_module(module) -> bool:
"""Check whether the module was compiled with torch.compile()"""
if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
return False
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
def fourier_filter(x_in, threshold, scale):
def fourier_filter(x_in: torch.Tensor, threshold: int, scale: float) -> torch.Tensor:
"""Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
This version of the method comes from here:
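As a standalone sketch of the FreeU-style low-frequency scaling that `fourier_filter` refers to, written from the paper's description rather than copied from the library:

```python
import torch
import torch.fft as fft


def fourier_filter_sketch(x_in: torch.Tensor, threshold: int, scale: float) -> torch.Tensor:
    """Scale the low-frequency components of a (B, C, H, W) feature map (illustrative sketch)."""
    x = x_in.to(torch.float32)                       # FFT needs a float/complex dtype
    x_freq = fft.fftn(x, dim=(-2, -1))
    x_freq = fft.fftshift(x_freq, dim=(-2, -1))      # move the DC component to the center

    B, C, H, W = x_freq.shape
    mask = torch.ones((B, C, H, W), device=x.device)
    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale
    x_freq = x_freq * mask                           # damp (or boost) the low frequencies

    x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
    x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
    return x_filtered.to(x_in.dtype)


filtered = fourier_filter_sketch(torch.randn(1, 4, 32, 32), threshold=1, scale=0.9)
print(filtered.shape)
```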