diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 28a12f2d13..de60c46eb2 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -13,7 +13,7 @@ # limitations under the License. import warnings -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union import numpy as np import PIL.Image @@ -126,14 +126,14 @@ class VaeImageProcessor(ConfigMixin): return images @staticmethod - def normalize(images): + def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: """ Normalize an image array to [-1,1]. """ return 2.0 * images - 1.0 @staticmethod - def denormalize(images): + def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: """ Denormalize an image array to [0,1]. """ @@ -159,10 +159,10 @@ class VaeImageProcessor(ConfigMixin): def get_default_height_width( self, - image: [PIL.Image.Image, np.ndarray, torch.Tensor], + image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], height: Optional[int] = None, width: Optional[int] = None, - ): + ) -> Tuple[int, int]: """ This function return the height and width that are downscaled to the next integer multiple of `vae_scale_factor`. @@ -202,12 +202,24 @@ class VaeImageProcessor(ConfigMixin): def resize( self, - image: [PIL.Image.Image, np.ndarray, torch.Tensor], + image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], height: Optional[int] = None, width: Optional[int] = None, - ) -> [PIL.Image.Image, np.ndarray, torch.Tensor]: + ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]: """ Resize image. + + Args: + image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`): + The image input, can be a PIL image, numpy array or pytorch tensor. + height (`int`, *optional*, defaults to `None`): + The height to resize to. + width (`int`, *optional*`, defaults to `None`): + The width to resize to. + + Returns: + `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`: + The resized image. """ if isinstance(image, PIL.Image.Image): image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample]) @@ -227,7 +239,15 @@ class VaeImageProcessor(ConfigMixin): def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image: """ - create a mask + Create a mask. + + Args: + image (`PIL.Image.Image`): + The image input, should be a PIL image. + + Returns: + `PIL.Image.Image`: + The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1. """ image[image < 0.5] = 0 image[image >= 0.5] = 1 @@ -327,7 +347,23 @@ class VaeImageProcessor(ConfigMixin): image: torch.FloatTensor, output_type: str = "pil", do_denormalize: Optional[List[bool]] = None, - ): + ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]: + """ + Postprocess the image output from tensor to `output_type`. + + Args: + image (`torch.FloatTensor`): + The image input, should be a pytorch tensor with shape `B x C x H x W`. + output_type (`str`, *optional*, defaults to `pil`): + The output type of the image, can be one of `pil`, `np`, `pt`, `latent`. + do_denormalize (`List[bool]`, *optional*, defaults to `None`): + Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the + `VaeImageProcessor` config. + + Returns: + `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`: + The postprocessed image. + """ if not isinstance(image, torch.Tensor): raise ValueError( f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor" @@ -390,7 +426,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor): super().__init__() @staticmethod - def numpy_to_pil(images): + def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]: """ Convert a NumPy image or a batch of images to a PIL image. """ @@ -406,7 +442,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor): return pil_images @staticmethod - def rgblike_to_depthmap(image): + def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: """ Args: image: RGB-like depth image @@ -416,7 +452,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor): """ return image[:, :, 1] * 2**8 + image[:, :, 2] - def numpy_to_depth(self, images): + def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]: """ Convert a NumPy depth image or a batch of images to a PIL image. """ @@ -441,7 +477,23 @@ class VaeImageProcessorLDM3D(VaeImageProcessor): image: torch.FloatTensor, output_type: str = "pil", do_denormalize: Optional[List[bool]] = None, - ): + ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]: + """ + Postprocess the image output from tensor to `output_type`. + + Args: + image (`torch.FloatTensor`): + The image input, should be a pytorch tensor with shape `B x C x H x W`. + output_type (`str`, *optional*, defaults to `pil`): + The output type of the image, can be one of `pil`, `np`, `pt`, `latent`. + do_denormalize (`List[bool]`, *optional*, defaults to `None`): + Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the + `VaeImageProcessor` config. + + Returns: + `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`: + The postprocessed image. + """ if not isinstance(image, torch.Tensor): raise ValueError( f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor" diff --git a/src/diffusers/models/autoencoder_asym_kl.py b/src/diffusers/models/autoencoder_asym_kl.py index 9f0fa62d34..656683b43f 100644 --- a/src/diffusers/models/autoencoder_asym_kl.py +++ b/src/diffusers/models/autoencoder_asym_kl.py @@ -65,11 +65,11 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): self, in_channels: int = 3, out_channels: int = 3, - down_block_types: Tuple[str] = ("DownEncoderBlock2D",), - down_block_out_channels: Tuple[int] = (64,), + down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",), + down_block_out_channels: Tuple[int, ...] = (64,), layers_per_down_block: int = 1, - up_block_types: Tuple[str] = ("UpDecoderBlock2D",), - up_block_out_channels: Tuple[int] = (64,), + up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), + up_block_out_channels: Tuple[int, ...] = (64,), layers_per_up_block: int = 1, act_fn: str = "silu", latent_channels: int = 4, @@ -109,7 +109,9 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): self.use_tiling = False @apply_forward_hook - def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + def encode( + self, x: torch.FloatTensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[torch.FloatTensor]]: h = self.encoder(x) moments = self.quant_conv(h) posterior = DiagonalGaussianDistribution(moments) @@ -125,7 +127,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): image: Optional[torch.FloatTensor] = None, mask: Optional[torch.FloatTensor] = None, return_dict: bool = True, - ) -> Union[DecoderOutput, torch.FloatTensor]: + ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]: z = self.post_quant_conv(z) dec = self.decoder(z, image, mask) @@ -142,7 +144,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): image: Optional[torch.FloatTensor] = None, mask: Optional[torch.FloatTensor] = None, return_dict: bool = True, - ) -> Union[DecoderOutput, torch.FloatTensor]: + ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]: decoded = self._decode(z, image, mask).sample if not return_dict: @@ -157,7 +159,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): sample_posterior: bool = False, return_dict: bool = True, generator: Optional[torch.Generator] = None, - ) -> Union[DecoderOutput, torch.FloatTensor]: + ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]: r""" Args: sample (`torch.FloatTensor`): Input sample. diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index ac616530a6..9003d982b3 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -322,13 +322,13 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin): return DecoderOutput(sample=decoded) - def blend_v(self, a, b, blend_extent): + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: blend_extent = min(a.shape[2], b.shape[2], blend_extent) for y in range(blend_extent): b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) return b - def blend_h(self, a, b, blend_extent): + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: blend_extent = min(a.shape[3], b.shape[3], blend_extent) for x in range(blend_extent): b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) diff --git a/src/diffusers/models/autoencoder_tiny.py b/src/diffusers/models/autoencoder_tiny.py index 15bd53ff99..0df97ed228 100644 --- a/src/diffusers/models/autoencoder_tiny.py +++ b/src/diffusers/models/autoencoder_tiny.py @@ -96,18 +96,18 @@ class AutoencoderTiny(ModelMixin, ConfigMixin): @register_to_config def __init__( self, - in_channels=3, - out_channels=3, - encoder_block_out_channels: Tuple[int] = (64, 64, 64, 64), - decoder_block_out_channels: Tuple[int] = (64, 64, 64, 64), + in_channels: int = 3, + out_channels: int = 3, + encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64), + decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64), act_fn: str = "relu", latent_channels: int = 4, upsampling_scaling_factor: int = 2, - num_encoder_blocks: Tuple[int] = (1, 3, 3, 3), - num_decoder_blocks: Tuple[int] = (3, 3, 3, 1), + num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3), + num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1), latent_magnitude: int = 3, latent_shift: float = 0.5, - force_upcast: float = False, + force_upcast: bool = False, scaling_factor: float = 1.0, ): super().__init__() @@ -147,33 +147,33 @@ class AutoencoderTiny(ModelMixin, ConfigMixin): self.tile_sample_min_size = 512 self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: if isinstance(module, (EncoderTiny, DecoderTiny)): module.gradient_checkpointing = value - def scale_latents(self, x): + def scale_latents(self, x: torch.FloatTensor) -> torch.FloatTensor: """raw latents -> [0, 1]""" return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1) - def unscale_latents(self, x): + def unscale_latents(self, x: torch.FloatTensor) -> torch.FloatTensor: """[0, 1] -> raw latents""" return x.sub(self.latent_shift).mul(2 * self.latent_magnitude) - def enable_slicing(self): + def enable_slicing(self) -> None: r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ self.use_slicing = True - def disable_slicing(self): + def disable_slicing(self) -> None: r""" Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing decoding in one step. """ self.use_slicing = False - def enable_tiling(self, use_tiling: bool = True): + def enable_tiling(self, use_tiling: bool = True) -> None: r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow @@ -181,7 +181,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin): """ self.use_tiling = use_tiling - def disable_tiling(self): + def disable_tiling(self) -> None: r""" Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing decoding in one step. @@ -197,13 +197,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin): Args: x (`torch.FloatTensor`): Input batch of images. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple. Returns: - [`~models.autoencoder_tiny.AutoencoderTinyOutput`] or `tuple`: - If return_dict is True, a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] is returned, otherwise a - plain `tuple` is returned. + `torch.FloatTensor`: Encoded batch of images. """ # scale of encoder output relative to input sf = self.spatial_scale_factor @@ -249,13 +245,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin): Args: x (`torch.FloatTensor`): Input batch of images. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple. Returns: - [`~models.vae.DecoderOutput`] or `tuple`: - If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is - returned. + `torch.FloatTensor`: Encoded batch of images. """ # scale of decoder output relative to input sf = self.spatial_scale_factor diff --git a/src/diffusers/models/consistency_decoder_vae.py b/src/diffusers/models/consistency_decoder_vae.py index b182733318..a2d82e2565 100644 --- a/src/diffusers/models/consistency_decoder_vae.py +++ b/src/diffusers/models/consistency_decoder_vae.py @@ -70,39 +70,39 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): @register_to_config def __init__( self, - scaling_factor=0.18215, - latent_channels=4, - encoder_act_fn="silu", - encoder_block_out_channels=(128, 256, 512, 512), - encoder_double_z=True, - encoder_down_block_types=( + scaling_factor: float = 0.18215, + latent_channels: int = 4, + encoder_act_fn: str = "silu", + encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + encoder_double_z: bool = True, + encoder_down_block_types: Tuple[str, ...] = ( "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", ), - encoder_in_channels=3, - encoder_layers_per_block=2, - encoder_norm_num_groups=32, - encoder_out_channels=4, - decoder_add_attention=False, - decoder_block_out_channels=(320, 640, 1024, 1024), - decoder_down_block_types=( + encoder_in_channels: int = 3, + encoder_layers_per_block: int = 2, + encoder_norm_num_groups: int = 32, + encoder_out_channels: int = 4, + decoder_add_attention: bool = False, + decoder_block_out_channels: Tuple[int, ...] = (320, 640, 1024, 1024), + decoder_down_block_types: Tuple[str, ...] = ( "ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D", ), - decoder_downsample_padding=1, - decoder_in_channels=7, - decoder_layers_per_block=3, - decoder_norm_eps=1e-05, - decoder_norm_num_groups=32, - decoder_num_train_timesteps=1024, - decoder_out_channels=6, - decoder_resnet_time_scale_shift="scale_shift", - decoder_time_embedding_type="learned", - decoder_up_block_types=( + decoder_downsample_padding: int = 1, + decoder_in_channels: int = 7, + decoder_layers_per_block: int = 3, + decoder_norm_eps: float = 1e-05, + decoder_norm_num_groups: int = 32, + decoder_num_train_timesteps: int = 1024, + decoder_out_channels: int = 6, + decoder_resnet_time_scale_shift: str = "scale_shift", + decoder_time_embedding_type: str = "learned", + decoder_up_block_types: Tuple[str, ...] = ( "ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D", @@ -304,8 +304,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): z: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True, - num_inference_steps=2, - ) -> Union[DecoderOutput, torch.FloatTensor]: + num_inference_steps: int = 2, + ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]: z = (z * self.config.scaling_factor - self.means) / self.stds scale_factor = 2 ** (len(self.config.block_out_channels) - 1) @@ -333,14 +333,14 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): return DecoderOutput(sample=x_0) # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_v - def blend_v(self, a, b, blend_extent): + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: blend_extent = min(a.shape[2], b.shape[2], blend_extent) for y in range(blend_extent): b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) return b # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_h - def blend_h(self, a, b, blend_extent): + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: blend_extent = min(a.shape[3], b.shape[3], blend_extent) for x in range(blend_extent): b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) @@ -407,7 +407,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): sample_posterior: bool = False, return_dict: bool = True, generator: Optional[torch.Generator] = None, - ) -> Union[DecoderOutput, torch.FloatTensor]: + ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]: r""" Args: sample (`torch.FloatTensor`): Input sample. @@ -415,6 +415,12 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): Whether to sample from the posterior. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + generator (`torch.Generator`, *optional*, defaults to `None`): + Generator to use for sampling. + + Returns: + [`DecoderOutput`] or `tuple`: + If return_dict is True, a [`DecoderOutput`] is returned, otherwise a plain `tuple` is returned. """ x = sample posterior = self.encode(x).latent_dist diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 052335f6c5..220e34593c 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -76,7 +76,7 @@ class ControlNetConditioningEmbedding(nn.Module): self, conditioning_embedding_channels: int, conditioning_channels: int = 3, - block_out_channels: Tuple[int] = (16, 32, 96, 256), + block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), ): super().__init__() @@ -171,6 +171,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): The tuple of output channel for each block in the `conditioning_embedding` layer. global_pool_conditions (`bool`, defaults to `False`): + TODO(Patrick) - unused parameter. + addition_embed_type_num_heads (`int`, defaults to 64): + The number of heads to use for the `TextTimeEmbedding` layer. """ _supports_gradient_checkpointing = True @@ -182,14 +185,14 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): conditioning_channels: int = 3, flip_sin_to_cos: bool = True, freq_shift: int = 0, - down_block_types: Tuple[str] = ( + down_block_types: Tuple[str, ...] = ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ), only_cross_attention: Union[bool, Tuple[bool]] = False, - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), layers_per_block: int = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, @@ -197,11 +200,11 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, + transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, - attention_head_dim: Union[int, Tuple[int]] = 8, - num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + attention_head_dim: Union[int, Tuple[int, ...]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, addition_embed_type: Optional[str] = None, @@ -211,9 +214,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): resnet_time_scale_shift: str = "default", projection_class_embeddings_input_dim: Optional[int] = None, controlnet_conditioning_channel_order: str = "rgb", - conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), global_pool_conditions: bool = False, - addition_embed_type_num_heads=64, + addition_embed_type_num_heads: int = 64, ): super().__init__() @@ -426,7 +429,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): cls, unet: UNet2DConditionModel, controlnet_conditioning_channel_order: str = "rgb", - conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), load_weights_from_unet: bool = True, ): r""" @@ -570,7 +573,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): self.set_attn_processor(processor, _remove_lora=True) # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice - def set_attention_slice(self, slice_size): + def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: r""" Enable sliced attention computation. @@ -635,7 +638,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): module.gradient_checkpointing = value @@ -653,7 +656,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): cross_attention_kwargs: Optional[Dict[str, Any]] = None, guess_mode: bool = False, return_dict: bool = True, - ) -> Union[ControlNetOutput, Tuple]: + ) -> Union[ControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]: """ The [`ControlNetModel`] forward method. diff --git a/src/diffusers/models/controlnet_flax.py b/src/diffusers/models/controlnet_flax.py index 076e618321..10059ffd6f 100644 --- a/src/diffusers/models/controlnet_flax.py +++ b/src/diffusers/models/controlnet_flax.py @@ -46,10 +46,10 @@ class FlaxControlNetOutput(BaseOutput): class FlaxControlNetConditioningEmbedding(nn.Module): conditioning_embedding_channels: int - block_out_channels: Tuple[int] = (16, 32, 96, 256) + block_out_channels: Tuple[int, ...] = (16, 32, 96, 256) dtype: jnp.dtype = jnp.float32 - def setup(self): + def setup(self) -> None: self.conv_in = nn.Conv( self.block_out_channels[0], kernel_size=(3, 3), @@ -87,7 +87,7 @@ class FlaxControlNetConditioningEmbedding(nn.Module): dtype=self.dtype, ) - def __call__(self, conditioning): + def __call__(self, conditioning: jnp.ndarray) -> jnp.ndarray: embedding = self.conv_in(conditioning) embedding = nn.silu(embedding) @@ -148,17 +148,17 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin): """ sample_size: int = 32 in_channels: int = 4 - down_block_types: Tuple[str] = ( + down_block_types: Tuple[str, ...] = ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ) - only_cross_attention: Union[bool, Tuple[bool]] = False - block_out_channels: Tuple[int] = (320, 640, 1280, 1280) + only_cross_attention: Union[bool, Tuple[bool, ...]] = False + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280) layers_per_block: int = 2 - attention_head_dim: Union[int, Tuple[int]] = 8 - num_attention_heads: Optional[Union[int, Tuple[int]]] = None + attention_head_dim: Union[int, Tuple[int, ...]] = 8 + num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None cross_attention_dim: int = 1280 dropout: float = 0.0 use_linear_projection: bool = False @@ -166,7 +166,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin): flip_sin_to_cos: bool = True freq_shift: int = 0 controlnet_conditioning_channel_order: str = "rgb" - conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256) + conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256) def init_weights(self, rng: jax.Array) -> FrozenDict: # init input tensors @@ -182,7 +182,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin): return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"] - def setup(self): + def setup(self) -> None: block_out_channels = self.block_out_channels time_embed_dim = block_out_channels[0] * 4 @@ -312,21 +312,21 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin): def __call__( self, - sample, - timesteps, - encoder_hidden_states, - controlnet_cond, + sample: jnp.ndarray, + timesteps: Union[jnp.ndarray, float, int], + encoder_hidden_states: jnp.ndarray, + controlnet_cond: jnp.ndarray, conditioning_scale: float = 1.0, return_dict: bool = True, train: bool = False, - ) -> Union[FlaxControlNetOutput, Tuple]: + ) -> Union[FlaxControlNetOutput, Tuple[Tuple[jnp.ndarray, ...], jnp.ndarray]]: r""" Args: sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor timestep (`jnp.ndarray` or `float` or `int`): timesteps encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor - conditioning_scale: (`float`) the scale factor for controlnet outputs + conditioning_scale (`float`, *optional*, defaults to `1.0`): the scale factor for controlnet outputs return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a plain tuple. @@ -335,8 +335,8 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin): Returns: [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. - When returning a tuple, the first element is the sample tensor. + [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. """ channel_order = self.controlnet_conditioning_channel_order if channel_order == "bgr": diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 5fe3c5602f..4a9483feb4 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -18,13 +18,14 @@ import inspect import itertools import os import re +from collections import OrderedDict from functools import partial from typing import Any, Callable, List, Optional, Tuple, Union import safetensors import torch from huggingface_hub import create_repo -from torch import Tensor, device, nn +from torch import Tensor, nn from .. import __version__ from ..utils import ( @@ -61,7 +62,7 @@ if is_accelerate_available(): from accelerate.utils.versions import is_torch_version -def get_parameter_device(parameter: torch.nn.Module): +def get_parameter_device(parameter: torch.nn.Module) -> torch.device: try: parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers()) return next(parameters_and_buffers).device @@ -77,7 +78,7 @@ def get_parameter_device(parameter: torch.nn.Module): return first_tuple[1].device -def get_parameter_dtype(parameter: torch.nn.Module): +def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype: try: params = tuple(parameter.parameters()) if len(params) > 0: @@ -130,7 +131,13 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[ ) -def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_name_or_path=None): +def load_model_dict_into_meta( + model, + state_dict: OrderedDict, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[Union[str, torch.dtype]] = None, + model_name_or_path: Optional[str] = None, +) -> List[str]: device = device or torch.device("cpu") dtype = dtype or torch.float32 @@ -156,7 +163,7 @@ def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_ return unexpected_keys -def _load_state_dict_into_model(model_to_load, state_dict): +def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]: # Convert old format to new format if needed from a PyTorch state_dict # copy state_dict so _load_from_state_dict can modify it state_dict = state_dict.copy() @@ -164,7 +171,7 @@ def _load_state_dict_into_model(model_to_load, state_dict): # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants # so we need to apply the function recursively. - def load(module: torch.nn.Module, prefix=""): + def load(module: torch.nn.Module, prefix: str = ""): args = (state_dict, prefix, {}, True, [], [], error_msgs) module._load_from_state_dict(*args) @@ -220,7 +227,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): """ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) - def enable_gradient_checkpointing(self): + def enable_gradient_checkpointing(self) -> None: """ Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or *checkpoint activations* in other frameworks). @@ -229,7 +236,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") self.apply(partial(self._set_gradient_checkpointing, value=True)) - def disable_gradient_checkpointing(self): + def disable_gradient_checkpointing(self) -> None: """ Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or *checkpoint activations* in other frameworks). @@ -254,7 +261,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): if isinstance(module, torch.nn.Module): fn_recursive_set_mem_eff(module) - def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None: r""" Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). @@ -290,7 +297,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): """ self.set_use_memory_efficient_attention_xformers(True, attention_op) - def disable_xformers_memory_efficient_attention(self): + def disable_xformers_memory_efficient_attention(self) -> None: r""" Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). """ @@ -447,7 +454,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): self, save_directory: Union[str, os.PathLike], is_main_process: bool = True, - save_function: Callable = None, + save_function: Optional[Callable] = None, safe_serialization: bool = True, variant: Optional[str] = None, push_to_hub: bool = False, @@ -910,10 +917,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): def _load_pretrained_model( cls, model, - state_dict, + state_dict: OrderedDict, resolved_archive_file, - pretrained_model_name_or_path, - ignore_mismatched_sizes=False, + pretrained_model_name_or_path: Union[str, os.PathLike], + ignore_mismatched_sizes: bool = False, ): # Retrieve missing & unexpected_keys model_state_dict = model.state_dict() @@ -1011,7 +1018,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs @property - def device(self) -> device: + def device(self) -> torch.device: """ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device). @@ -1063,7 +1070,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): else: return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) - def _convert_deprecated_attention_blocks(self, state_dict): + def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None: deprecated_attention_block_paths = [] def recursive_find_attn_block(name, module): @@ -1107,7 +1114,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): if f"{path}.proj_attn.bias" in state_dict: state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias") - def _temp_convert_self_to_deprecated_attention_blocks(self): + def _temp_convert_self_to_deprecated_attention_blocks(self) -> None: deprecated_attention_block_modules = [] def recursive_find_attn_block(module): @@ -1134,10 +1141,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): del module.to_v del module.to_out - def _undo_temp_convert_self_to_deprecated_attention_blocks(self): + def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None: deprecated_attention_block_modules = [] - def recursive_find_attn_block(module): + def recursive_find_attn_block(module) -> None: if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: deprecated_attention_block_modules.append(module) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index cedeff18f3..11d2a34474 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -101,8 +101,8 @@ class AdaLayerNormSingle(nn.Module): def forward( self, timestep: torch.Tensor, - added_cond_kwargs: Dict[str, torch.Tensor] = None, - batch_size: int = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + batch_size: Optional[int] = None, hidden_dtype: Optional[torch.dtype] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # No modulation happening here. diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index c47ade0348..7a48d343a5 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -164,7 +164,9 @@ class Upsample2D(nn.Module): else: self.Conv2d_0 = conv - def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, scale: float = 1.0): + def forward( + self, hidden_states: torch.FloatTensor, output_size: Optional[int] = None, scale: float = 1.0 + ) -> torch.FloatTensor: assert hidden_states.shape[1] == self.channels if self.use_conv_transpose: @@ -256,7 +258,7 @@ class Downsample2D(nn.Module): else: self.conv = conv - def forward(self, hidden_states, scale: float = 1.0): + def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: assert hidden_states.shape[1] == self.channels if self.use_conv and self.padding == 0: @@ -280,7 +282,7 @@ class FirUpsample2D(nn.Module): """A 2D FIR upsampling layer with an optional convolution. Parameters: - channels (`int`): + channels (`int`, optional): number of channels in the inputs and outputs. use_conv (`bool`, default `False`): option to use a convolution. @@ -292,7 +294,7 @@ class FirUpsample2D(nn.Module): def __init__( self, - channels: int = None, + channels: Optional[int] = None, out_channels: Optional[int] = None, use_conv: bool = False, fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1), @@ -307,12 +309,12 @@ class FirUpsample2D(nn.Module): def _upsample_2d( self, - hidden_states: torch.Tensor, - weight: Optional[torch.Tensor] = None, + hidden_states: torch.FloatTensor, + weight: Optional[torch.FloatTensor] = None, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1, - ) -> torch.Tensor: + ) -> torch.FloatTensor: """Fused `upsample_2d()` followed by `Conv2d()`. Padding is performed only once at the beginning, not between the operations. The fused op is considerably more @@ -320,17 +322,21 @@ class FirUpsample2D(nn.Module): arbitrary order. Args: - hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - weight: Weight tensor of the shape `[filterH, filterW, inChannels, - outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` - (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. - factor: Integer upsampling factor (default: 2). - gain: Scaling factor for signal magnitude (default: 1.0). + hidden_states (`torch.FloatTensor`): + Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + weight (`torch.FloatTensor`, *optional*): + Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be + performed by `inChannels = x.shape[0] // numGroups`. + kernel (`torch.FloatTensor`, *optional*): + FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which + corresponds to nearest-neighbor upsampling. + factor (`int`, *optional*): Integer upsampling factor (default: 2). + gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0). Returns: - output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same - datatype as `hidden_states`. + output (`torch.FloatTensor`): + Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same + datatype as `hidden_states`. """ assert isinstance(factor, int) and factor >= 1 @@ -392,7 +398,7 @@ class FirUpsample2D(nn.Module): return output - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: if self.use_conv: height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel) height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) @@ -418,7 +424,7 @@ class FirDownsample2D(nn.Module): def __init__( self, - channels: int = None, + channels: Optional[int] = None, out_channels: Optional[int] = None, use_conv: bool = False, fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1), @@ -433,30 +439,35 @@ class FirDownsample2D(nn.Module): def _downsample_2d( self, - hidden_states: torch.Tensor, - weight: Optional[torch.Tensor] = None, + hidden_states: torch.FloatTensor, + weight: Optional[torch.FloatTensor] = None, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1, - ) -> torch.Tensor: + ) -> torch.FloatTensor: """Fused `Conv2d()` followed by `downsample_2d()`. Padding is performed only once at the beginning, not between the operations. The fused op is considerably more efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary order. Args: - hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - weight: + hidden_states (`torch.FloatTensor`): + Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + weight (`torch.FloatTensor`, *optional*): Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * - factor`, which corresponds to average pooling. - factor: Integer downsampling factor (default: 2). - gain: Scaling factor for signal magnitude (default: 1.0). + kernel (`torch.FloatTensor`, *optional*): + FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which + corresponds to average pooling. + factor (`int`, *optional*, default to `2`): + Integer downsampling factor. + gain (`float`, *optional*, default to `1.0`): + Scaling factor for signal magnitude. Returns: - output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and - same datatype as `x`. + output (`torch.FloatTensor`): + Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same + datatype as `x`. """ assert isinstance(factor, int) and factor >= 1 @@ -492,7 +503,7 @@ class FirDownsample2D(nn.Module): return output - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: if self.use_conv: downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1) @@ -682,7 +693,9 @@ class ResnetBlock2D(nn.Module): in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias ) - def forward(self, input_tensor, temb, scale: float = 1.0): + def forward( + self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, scale: float = 1.0 + ) -> torch.FloatTensor: hidden_states = input_tensor if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": @@ -778,7 +791,7 @@ class Conv1dBlock(nn.Module): out_channels (`int`): Number of output channels. kernel_size (`int` or `tuple`): Size of the convolving kernel. n_groups (`int`, default `8`): Number of groups to separate the channels into. - activation (`str`, defaults `mish`): Name of the activation function. + activation (`str`, defaults to `mish`): Name of the activation function. """ def __init__( @@ -853,8 +866,8 @@ class ResidualTemporalBlock1D(nn.Module): def upsample_2d( - hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1 -) -> torch.Tensor: + hidden_states: torch.FloatTensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1 +) -> torch.FloatTensor: r"""Upsample2D a batch of 2D images with the given filter. Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified @@ -862,14 +875,19 @@ def upsample_2d( a: multiple of the upsampling factor. Args: - hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` - (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. - factor: Integer upsampling factor (default: 2). - gain: Scaling factor for signal magnitude (default: 1.0). + hidden_states (`torch.FloatTensor`): + Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + kernel (`torch.FloatTensor`, *optional*): + FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which + corresponds to nearest-neighbor upsampling. + factor (`int`, *optional*, default to `2`): + Integer upsampling factor. + gain (`float`, *optional*, default to `1.0`): + Scaling factor for signal magnitude (default: 1.0). Returns: - output: Tensor of the shape `[N, C, H * factor, W * factor]` + output (`torch.FloatTensor`): + Tensor of the shape `[N, C, H * factor, W * factor]` """ assert isinstance(factor, int) and factor >= 1 if kernel is None: @@ -892,8 +910,8 @@ def upsample_2d( def downsample_2d( - hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1 -) -> torch.Tensor: + hidden_states: torch.FloatTensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1 +) -> torch.FloatTensor: r"""Downsample2D a batch of 2D images with the given filter. Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the @@ -901,14 +919,19 @@ def downsample_2d( shape is a multiple of the downsampling factor. Args: - hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` - (separable). The default is `[1] * factor`, which corresponds to average pooling. - factor: Integer downsampling factor (default: 2). - gain: Scaling factor for signal magnitude (default: 1.0). + hidden_states (`torch.FloatTensor`) + Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + kernel (`torch.FloatTensor`, *optional*): + FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which + corresponds to average pooling. + factor (`int`, *optional*, default to `2`): + Integer downsampling factor. + gain (`float`, *optional*, default to `1.0`): + Scaling factor for signal magnitude. Returns: - output: Tensor of the shape `[N, C, H // factor, W // factor]` + output (`torch.FloatTensor`): + Tensor of the shape `[N, C, H // factor, W // factor]` """ assert isinstance(factor, int) and factor >= 1 diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 770cbf09cc..13f53e16e7 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -100,18 +100,18 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): sample_size: int = 32 in_channels: int = 4 out_channels: int = 4 - down_block_types: Tuple[str] = ( + down_block_types: Tuple[str, ...] = ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ) - up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") + up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") only_cross_attention: Union[bool, Tuple[bool]] = False - block_out_channels: Tuple[int] = (320, 640, 1280, 1280) + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280) layers_per_block: int = 2 - attention_head_dim: Union[int, Tuple[int]] = 8 - num_attention_heads: Optional[Union[int, Tuple[int]]] = None + attention_head_dim: Union[int, Tuple[int, ...]] = 8 + num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None cross_attention_dim: int = 1280 dropout: float = 0.0 use_linear_projection: bool = False @@ -120,7 +120,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): freq_shift: int = 0 use_memory_efficient_attention: bool = False split_head_dim: bool = False - transformer_layers_per_block: Union[int, Tuple[int]] = 1 + transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1 addition_embed_type: Optional[str] = None addition_time_embed_dim: Optional[int] = None addition_embed_type_num_heads: int = 64 @@ -158,7 +158,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): } return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"] - def setup(self): + def setup(self) -> None: block_out_channels = self.block_out_channels time_embed_dim = block_out_channels[0] * 4 @@ -320,15 +320,15 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): def __call__( self, - sample, - timesteps, - encoder_hidden_states, + sample: jnp.ndarray, + timesteps: Union[jnp.ndarray, float, int], + encoder_hidden_states: jnp.ndarray, added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None, - down_block_additional_residuals=None, - mid_block_additional_residual=None, + down_block_additional_residuals: Optional[Tuple[jnp.ndarray, ...]] = None, + mid_block_additional_residual: Optional[jnp.ndarray] = None, return_dict: bool = True, train: bool = False, - ) -> Union[FlaxUNet2DConditionOutput, Tuple]: + ) -> Union[FlaxUNet2DConditionOutput, Tuple[jnp.ndarray]]: r""" Args: sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index 97a1f1037c..767ab846d5 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import torch from torch import nn @@ -26,26 +26,26 @@ from .transformer_temporal import TransformerTemporalModel def get_down_block( - down_block_type, - num_layers, - in_channels, - out_channels, - temb_channels, - add_downsample, - resnet_eps, - resnet_act_fn, - num_attention_heads, - resnet_groups=None, - cross_attention_dim=None, - downsample_padding=None, - dual_cross_attention=False, - use_linear_projection=True, - only_cross_attention=False, - upcast_attention=False, - resnet_time_scale_shift="default", - temporal_num_attention_heads=8, - temporal_max_seq_length=32, -): + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, + resnet_eps: float, + resnet_act_fn: str, + num_attention_heads: int, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + downsample_padding: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = True, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, +) -> Union["DownBlock3D", "CrossAttnDownBlock3D", "DownBlockMotion", "CrossAttnDownBlockMotion"]: if down_block_type == "DownBlock3D": return DownBlock3D( num_layers=num_layers, @@ -123,28 +123,28 @@ def get_down_block( def get_up_block( - up_block_type, - num_layers, - in_channels, - out_channels, - prev_output_channel, - temb_channels, - add_upsample, - resnet_eps, - resnet_act_fn, - num_attention_heads, - resolution_idx=None, - resnet_groups=None, - cross_attention_dim=None, - dual_cross_attention=False, - use_linear_projection=True, - only_cross_attention=False, - upcast_attention=False, - resnet_time_scale_shift="default", - temporal_num_attention_heads=8, - temporal_cross_attention_dim=None, - temporal_max_seq_length=32, -): + up_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + add_upsample: bool, + resnet_eps: float, + resnet_act_fn: str, + num_attention_heads: int, + resolution_idx: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = True, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + temporal_num_attention_heads: int = 8, + temporal_cross_attention_dim: Optional[int] = None, + temporal_max_seq_length: int = 32, +) -> Union["UpBlock3D", "CrossAttnUpBlock3D", "UpBlockMotion", "CrossAttnUpBlockMotion"]: if up_block_type == "UpBlock3D": return UpBlock3D( num_layers=num_layers, @@ -236,12 +236,12 @@ class UNetMidBlock3DCrossAttn(nn.Module): resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - num_attention_heads=1, - output_scale_factor=1.0, - cross_attention_dim=1280, - dual_cross_attention=False, - use_linear_projection=True, - upcast_attention=False, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = True, + upcast_attention: bool = False, ): super().__init__() @@ -328,13 +328,13 @@ class UNetMidBlock3DCrossAttn(nn.Module): def forward( self, - hidden_states, - temb=None, - encoder_hidden_states=None, - attention_mask=None, - num_frames=1, - cross_attention_kwargs=None, - ): + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.FloatTensor: hidden_states = self.resnets[0](hidden_states, temb) hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames) for attn, temp_attn, resnet, temp_conv in zip( @@ -368,15 +368,15 @@ class CrossAttnDownBlock3D(nn.Module): resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - num_attention_heads=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - downsample_padding=1, - add_downsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, ): super().__init__() resnets = [] @@ -454,13 +454,13 @@ class CrossAttnDownBlock3D(nn.Module): def forward( self, - hidden_states, - temb=None, - encoder_hidden_states=None, - attention_mask=None, - num_frames=1, - cross_attention_kwargs=None, - ): + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Dict[str, Any] = None, + ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: # TODO(Patrick, William) - attention mask is not used output_states = () @@ -503,9 +503,9 @@ class DownBlock3D(nn.Module): resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, ): super().__init__() resnets = [] @@ -552,7 +552,9 @@ class DownBlock3D(nn.Module): self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None, num_frames=1): + def forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, num_frames: int = 1 + ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: output_states = () for resnet, temp_conv in zip(self.resnets, self.temp_convs): @@ -584,15 +586,15 @@ class CrossAttnUpBlock3D(nn.Module): resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - num_attention_heads=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - add_upsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - resolution_idx=None, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resolution_idx: Optional[int] = None, ): super().__init__() resnets = [] @@ -667,15 +669,15 @@ class CrossAttnUpBlock3D(nn.Module): def forward( self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - upsample_size=None, - attention_mask=None, - num_frames=1, - cross_attention_kwargs=None, - ): + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Dict[str, Any] = None, + ) -> torch.FloatTensor: is_freeu_enabled = ( getattr(self, "s1", None) and getattr(self, "s2", None) @@ -738,9 +740,9 @@ class UpBlock3D(nn.Module): resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, - resolution_idx=None, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + resolution_idx: Optional[int] = None, ): super().__init__() resnets = [] @@ -784,7 +786,14 @@ class UpBlock3D(nn.Module): self.gradient_checkpointing = False self.resolution_idx = resolution_idx - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1): + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + num_frames: int = 1, + ) -> torch.FloatTensor: is_freeu_enabled = ( getattr(self, "s1", None) and getattr(self, "s2", None) @@ -833,12 +842,12 @@ class DownBlockMotion(nn.Module): resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, - temporal_num_attention_heads=1, - temporal_cross_attention_dim=None, - temporal_max_seq_length=32, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + temporal_num_attention_heads: int = 1, + temporal_cross_attention_dim: Optional[int] = None, + temporal_max_seq_length: int = 32, ): super().__init__() resnets = [] @@ -890,7 +899,13 @@ class DownBlockMotion(nn.Module): self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None, scale: float = 1.0, num_frames=1): + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + num_frames: int = 1, + ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: output_states = () blocks = zip(self.resnets, self.motion_modules) @@ -944,19 +959,19 @@ class CrossAttnDownBlockMotion(nn.Module): resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - num_attention_heads=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - downsample_padding=1, - add_downsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - attention_type="default", - temporal_cross_attention_dim=None, - temporal_num_attention_heads=8, - temporal_max_seq_length=32, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + temporal_cross_attention_dim: Optional[int] = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, ): super().__init__() resnets = [] @@ -1043,14 +1058,14 @@ class CrossAttnDownBlockMotion(nn.Module): def forward( self, - hidden_states, - temb=None, - encoder_hidden_states=None, - attention_mask=None, - num_frames=1, - encoder_attention_mask=None, - cross_attention_kwargs=None, - additional_residuals=None, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + additional_residuals: Optional[torch.FloatTensor] = None, ): output_states = () @@ -1121,7 +1136,7 @@ class CrossAttnUpBlockMotion(nn.Module): out_channels: int, prev_output_channel: int, temb_channels: int, - resolution_idx: int = None, + resolution_idx: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: int = 1, @@ -1130,18 +1145,18 @@ class CrossAttnUpBlockMotion(nn.Module): resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - num_attention_heads=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - add_upsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - attention_type="default", - temporal_cross_attention_dim=None, - temporal_num_attention_heads=8, - temporal_max_seq_length=32, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + temporal_cross_attention_dim: Optional[int] = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, ): super().__init__() resnets = [] @@ -1232,8 +1247,8 @@ class CrossAttnUpBlockMotion(nn.Module): upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - num_frames=1, - ): + num_frames: int = 1, + ) -> torch.FloatTensor: lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 is_freeu_enabled = ( getattr(self, "s1", None) @@ -1317,7 +1332,7 @@ class UpBlockMotion(nn.Module): prev_output_channel: int, out_channels: int, temb_channels: int, - resolution_idx: int = None, + resolution_idx: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, @@ -1325,12 +1340,12 @@ class UpBlockMotion(nn.Module): resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, - temporal_norm_num_groups=32, - temporal_cross_attention_dim=None, - temporal_num_attention_heads=8, - temporal_max_seq_length=32, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temporal_norm_num_groups: int = 32, + temporal_cross_attention_dim: Optional[int] = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, ): super().__init__() resnets = [] @@ -1381,8 +1396,14 @@ class UpBlockMotion(nn.Module): self.resolution_idx = resolution_idx def forward( - self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0, num_frames=1 - ): + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size=None, + scale: float = 1.0, + num_frames: int = 1, + ) -> torch.FloatTensor: is_freeu_enabled = ( getattr(self, "s1", None) and getattr(self, "s2", None) @@ -1457,16 +1478,16 @@ class UNetMidBlockCrossAttnMotion(nn.Module): resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - num_attention_heads=1, - output_scale_factor=1.0, - cross_attention_dim=1280, - dual_cross_attention=False, - use_linear_projection=False, - upcast_attention=False, - attention_type="default", - temporal_num_attention_heads=1, - temporal_cross_attention_dim=None, - temporal_max_seq_length=32, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: float = False, + use_linear_projection: float = False, + upcast_attention: float = False, + attention_type: str = "default", + temporal_num_attention_heads: int = 1, + temporal_cross_attention_dim: Optional[int] = None, + temporal_max_seq_length: int = 32, ): super().__init__() @@ -1560,7 +1581,7 @@ class UNetMidBlockCrossAttnMotion(nn.Module): attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - num_frames=1, + num_frames: int = 1, ) -> torch.FloatTensor: lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index f7ba9388a8..c6710256ef 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -98,14 +98,19 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, - down_block_types: Tuple[str] = ( + down_block_types: Tuple[str, ...] = ( "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D", ), - up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + up_block_types: Tuple[str, ...] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + ), + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), layers_per_block: int = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, @@ -302,7 +307,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice - def set_attention_slice(self, slice_size): + def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: r""" Enable sliced attention computation. @@ -404,7 +409,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - def enable_forward_chunking(self, chunk_size=None, dim=0): + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: """ Sets the attention processor to use [feed forward chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). @@ -460,7 +465,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) self.set_attn_processor(processor, _remove_lora=True) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): module.gradient_checkpointing = value @@ -510,7 +515,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, return_dict: bool = True, - ) -> Union[UNet3DConditionOutput, Tuple]: + ) -> Union[UNet3DConditionOutput, Tuple[torch.FloatTensor]]: r""" The [`UNet3DConditionModel`] forward method. diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index 5d528a34ec..ab84b4de13 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -50,14 +50,14 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name class MotionModules(nn.Module): def __init__( self, - in_channels, - layers_per_block=2, - num_attention_heads=8, - attention_bias=False, - cross_attention_dim=None, - activation_fn="geglu", - norm_num_groups=32, - max_seq_length=32, + in_channels: int, + layers_per_block: int = 2, + num_attention_heads: int = 8, + attention_bias: bool = False, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + norm_num_groups: int = 32, + max_seq_length: int = 32, ): super().__init__() self.motion_modules = nn.ModuleList([]) @@ -82,13 +82,13 @@ class MotionAdapter(ModelMixin, ConfigMixin): @register_to_config def __init__( self, - block_out_channels=(320, 640, 1280, 1280), - motion_layers_per_block=2, - motion_mid_block_layers_per_block=1, - motion_num_attention_heads=8, - motion_norm_num_groups=32, - motion_max_seq_length=32, - use_motion_mid_block=True, + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), + motion_layers_per_block: int = 2, + motion_mid_block_layers_per_block: int = 1, + motion_num_attention_heads: int = 8, + motion_norm_num_groups: int = 32, + motion_max_seq_length: int = 32, + use_motion_mid_block: bool = True, ): """Container to store AnimateDiff Motion Modules @@ -182,29 +182,29 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, - down_block_types: Tuple[str] = ( + down_block_types: Tuple[str, ...] = ( "CrossAttnDownBlockMotion", "CrossAttnDownBlockMotion", "CrossAttnDownBlockMotion", "DownBlockMotion", ), - up_block_types: Tuple[str] = ( + up_block_types: Tuple[str, ...] = ( "UpBlockMotion", "CrossAttnUpBlockMotion", "CrossAttnUpBlockMotion", "CrossAttnUpBlockMotion", ), - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), layers_per_block: int = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", - norm_num_groups: Optional[int] = 32, + norm_num_groups: int = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, use_linear_projection: bool = False, - num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, - motion_max_seq_length: Optional[int] = 32, + num_attention_heads: Union[int, Tuple[int, ...]] = 8, + motion_max_seq_length: int = 32, motion_num_attention_heads: int = 8, use_motion_mid_block: int = True, ): @@ -448,7 +448,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): return model - def freeze_unet2d_params(self): + def freeze_unet2d_params(self) -> None: """Freeze the weights of just the UNet2DConditionModel, and leave the motion modules unfrozen for fine tuning. """ @@ -472,9 +472,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): for param in motion_modules.parameters(): param.requires_grad = True - return - - def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]): + def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]) -> None: for i, down_block in enumerate(motion_adapter.down_blocks): self.down_blocks[i].motion_modules.load_state_dict(down_block.motion_modules.state_dict()) for i, up_block in enumerate(motion_adapter.up_blocks): @@ -492,7 +490,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): variant: Optional[str] = None, push_to_hub: bool = False, **kwargs, - ): + ) -> None: state_dict = self.state_dict() # Extract all motion modules @@ -582,7 +580,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): fn_recursive_attn_processor(name, module, processor) # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking - def enable_forward_chunking(self, chunk_size=None, dim=0): + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: """ Sets the attention processor to use [feed forward chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). @@ -612,7 +610,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): fn_recursive_feed_forward(module, chunk_size, dim) # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking - def disable_forward_chunking(self): + def disable_forward_chunking(self) -> None: def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) @@ -624,7 +622,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): fn_recursive_feed_forward(module, None, 0) # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor - def set_default_attn_processor(self): + def set_default_attn_processor(self) -> None: """ Disables custom attention processors and sets the default attention implementation. """ @@ -639,12 +637,12 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): self.set_attn_processor(processor, _remove_lora=True) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)): module.gradient_checkpointing = value # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu - def enable_freeu(self, s1, s2, b1, b2): + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None: r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. The suffixes after the scaling factors represent the stage blocks where they are being applied. @@ -669,7 +667,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): setattr(upsample_block, "b2", b2) # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu - def disable_freeu(self): + def disable_freeu(self) -> None: """Disables the FreeU mechanism.""" freeu_keys = {"s1", "s2", "b1", "b2"} for i, upsample_block in enumerate(self.up_blocks): @@ -688,7 +686,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, return_dict: bool = True, - ) -> Union[UNet3DConditionOutput, Tuple]: + ) -> Union[UNet3DConditionOutput, Tuple[torch.Tensor]]: r""" The [`UNetMotionModel`] forward method. diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index 0c93b9142b..f4a6c8fb22 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -148,7 +148,9 @@ class VQModel(ModelMixin, ConfigMixin): return DecoderOutput(sample=dec) - def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + def forward( + self, sample: torch.FloatTensor, return_dict: bool = True + ) -> Union[DecoderOutput, Tuple[torch.FloatTensor, ...]]: r""" The [`VQModel`] forward method. diff --git a/src/diffusers/optimization.py b/src/diffusers/optimization.py index 46e6125a0f..678d2c12cf 100644 --- a/src/diffusers/optimization.py +++ b/src/diffusers/optimization.py @@ -37,7 +37,7 @@ class SchedulerType(Enum): PIECEWISE_CONSTANT = "piecewise_constant" -def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1): +def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1) -> LambdaLR: """ Create a schedule with a constant learning rate, using the learning rate set in optimizer. @@ -53,7 +53,7 @@ def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1): return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) -def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1): +def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1) -> LambdaLR: """ Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate increases linearly between 0 and the initial lr set in the optimizer. @@ -78,7 +78,7 @@ def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: in return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) -def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1): +def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1) -> LambdaLR: """ Create a schedule with a constant learning rate, using the learning rate set in optimizer. @@ -120,7 +120,9 @@ def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_ return LambdaLR(optimizer, rules_func, last_epoch=last_epoch) -def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): +def get_linear_schedule_with_warmup( + optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, last_epoch: int = -1 +) -> LambdaLR: """ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. @@ -151,7 +153,7 @@ def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_st def get_cosine_schedule_with_warmup( optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1 -): +) -> LambdaLR: """ Create a schedule with a learning rate that decreases following the values of the cosine function between the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the @@ -185,7 +187,7 @@ def get_cosine_schedule_with_warmup( def get_cosine_with_hard_restarts_schedule_with_warmup( optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1 -): +) -> LambdaLR: """ Create a schedule with a learning rate that decreases following the values of the cosine function between the initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases @@ -219,8 +221,13 @@ def get_cosine_with_hard_restarts_schedule_with_warmup( def get_polynomial_decay_schedule_with_warmup( - optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1 -): + optimizer: Optimizer, + num_warmup_steps: int, + num_training_steps: int, + lr_end: float = 1e-7, + power: float = 1.0, + last_epoch: int = -1, +) -> LambdaLR: """ Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the @@ -288,7 +295,7 @@ def get_scheduler( num_cycles: int = 1, power: float = 1.0, last_epoch: int = -1, -): +) -> LambdaLR: """ Unified API to get any scheduler from its name. diff --git a/src/diffusers/utils/logging.py b/src/diffusers/utils/logging.py index 4ccc57cd69..6050f314c0 100644 --- a/src/diffusers/utils/logging.py +++ b/src/diffusers/utils/logging.py @@ -28,7 +28,7 @@ from logging import ( WARN, # NOQA WARNING, # NOQA ) -from typing import Optional +from typing import Dict, Optional from tqdm import auto as tqdm_lib @@ -49,7 +49,7 @@ _default_log_level = logging.WARNING _tqdm_active = True -def _get_default_logging_level(): +def _get_default_logging_level() -> int: """ If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is not - fall back to `_default_log_level` @@ -104,7 +104,7 @@ def _reset_library_root_logger() -> None: _default_handler = None -def get_log_levels_dict(): +def get_log_levels_dict() -> Dict[str, int]: return log_levels @@ -161,22 +161,22 @@ def set_verbosity(verbosity: int) -> None: _get_library_root_logger().setLevel(verbosity) -def set_verbosity_info(): +def set_verbosity_info() -> None: """Set the verbosity to the `INFO` level.""" return set_verbosity(INFO) -def set_verbosity_warning(): +def set_verbosity_warning() -> None: """Set the verbosity to the `WARNING` level.""" return set_verbosity(WARNING) -def set_verbosity_debug(): +def set_verbosity_debug() -> None: """Set the verbosity to the `DEBUG` level.""" return set_verbosity(DEBUG) -def set_verbosity_error(): +def set_verbosity_error() -> None: """Set the verbosity to the `ERROR` level.""" return set_verbosity(ERROR) @@ -263,7 +263,7 @@ def reset_format() -> None: handler.setFormatter(None) -def warning_advice(self, *args, **kwargs): +def warning_advice(self, *args, **kwargs) -> None: """ This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this warning will not be printed @@ -327,13 +327,13 @@ def is_progress_bar_enabled() -> bool: return bool(_tqdm_active) -def enable_progress_bar(): +def enable_progress_bar() -> None: """Enable tqdm progress bar.""" global _tqdm_active _tqdm_active = True -def disable_progress_bar(): +def disable_progress_bar() -> None: """Disable tqdm progress bar.""" global _tqdm_active _tqdm_active = False diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py index a057b506ae..01a2973619 100644 --- a/src/diffusers/utils/outputs.py +++ b/src/diffusers/utils/outputs.py @@ -24,7 +24,7 @@ import numpy as np from .import_utils import is_torch_available -def is_tensor(x): +def is_tensor(x) -> bool: """ Tests if `x` is a `torch.Tensor` or `np.ndarray`. """ @@ -66,7 +66,7 @@ class BaseOutput(OrderedDict): lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)), ) - def __post_init__(self): + def __post_init__(self) -> None: class_fields = fields(self) # Safety and consistency checks @@ -97,14 +97,14 @@ class BaseOutput(OrderedDict): def update(self, *args, **kwargs): raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") - def __getitem__(self, k): + def __getitem__(self, k: Any) -> Any: if isinstance(k, str): inner_dict = dict(self.items()) return inner_dict[k] else: return self.to_tuple()[k] - def __setattr__(self, name, value): + def __setattr__(self, name: Any, value: Any) -> None: if name in self.keys() and value is not None: # Don't call self.__setitem__ to avoid recursion errors super().__setitem__(name, value) @@ -123,7 +123,7 @@ class BaseOutput(OrderedDict): args = tuple(getattr(self, field.name) for field in fields(self)) return callable, args, *remaining - def to_tuple(self) -> Tuple[Any]: + def to_tuple(self) -> Tuple[Any, ...]: """ Convert self to a tuple containing all the attributes/keys that are not `None`. """ diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 7955ccb01d..00bc75f41b 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -82,14 +82,14 @@ def randn_tensor( return latents -def is_compiled_module(module): +def is_compiled_module(module) -> bool: """Check whether the module was compiled with torch.compile()""" if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"): return False return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) -def fourier_filter(x_in, threshold, scale): +def fourier_filter(x_in: torch.Tensor, threshold: int, scale: int) -> torch.Tensor: """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497). This version of the method comes from here: