Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Improve docs and type hints (#5759)
* improvement: docs and type hints
* improvement: docs and type hints minor refactor
* improvement: docs and type hints
* update with suggestions from review

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import warnings
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import PIL.Image
@@ -126,14 +126,14 @@ class VaeImageProcessor(ConfigMixin):
         return images
 
     @staticmethod
-    def normalize(images):
+    def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
         """
         Normalize an image array to [-1,1].
         """
         return 2.0 * images - 1.0
 
     @staticmethod
-    def denormalize(images):
+    def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
         """
         Denormalize an image array to [0,1].
         """
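For illustration (not part of this commit): the [-1, 1] round trip that these two static methods implement, sketched standalone with NumPy:

import numpy as np

img = np.random.rand(1, 64, 64, 3).astype(np.float32)  # pixel values in [0, 1]
normalized = 2.0 * img - 1.0                            # what normalize() computes
restored = (normalized / 2 + 0.5).clip(0, 1)            # what denormalize() computes
assert np.allclose(img, restored, atol=1e-6)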
@@ -159,10 +159,10 @@ class VaeImageProcessor(ConfigMixin):
 
     def get_default_height_width(
         self,
-        image: [PIL.Image.Image, np.ndarray, torch.Tensor],
+        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
         height: Optional[int] = None,
         width: Optional[int] = None,
-    ):
+    ) -> Tuple[int, int]:
         """
         This function return the height and width that are downscaled to the next integer multiple of
         `vae_scale_factor`.
@@ -202,12 +202,24 @@ class VaeImageProcessor(ConfigMixin):
 
     def resize(
         self,
-        image: [PIL.Image.Image, np.ndarray, torch.Tensor],
+        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
         height: Optional[int] = None,
         width: Optional[int] = None,
-    ) -> [PIL.Image.Image, np.ndarray, torch.Tensor]:
+    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
         """
         Resize image.
+
+        Args:
+            image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
+                The image input, can be a PIL image, numpy array or pytorch tensor.
+            height (`int`, *optional*, defaults to `None`):
+                The height to resize to.
+            width (`int`, *optional*, defaults to `None`):
+                The width to resize to.
+
+        Returns:
+            `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
+                The resized image.
         """
         if isinstance(image, PIL.Image.Image):
             image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
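A side note on the annotation bug fixed in the two hunks above (illustration only): a bare `[X, Y]` in an annotation position is just a list literal, not a type, so type checkers reject it; `Union[X, Y]` is the valid spelling:

from typing import Union

def bad(image: [int, str]) -> None:        # evaluates to a plain list; flagged by type checkers
    ...

def good(image: Union[int, str]) -> None:  # what this PR switches the signatures to
    ...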
@@ -227,7 +239,15 @@ class VaeImageProcessor(ConfigMixin):
 
     def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
         """
-        create a mask
+        Create a mask.
+
+        Args:
+            image (`PIL.Image.Image`):
+                The image input, should be a PIL image.
+
+        Returns:
+            `PIL.Image.Image`:
+                The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1.
         """
         image[image < 0.5] = 0
         image[image >= 0.5] = 1
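For illustration (not part of this commit): the same boolean-mask thresholding, standalone with NumPy:

import numpy as np

mask = np.array([[0.1, 0.7], [0.5, 0.3]], dtype=np.float32)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
print(mask)  # [[0. 1.] [1. 0.]]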
@@ -327,7 +347,23 @@ class VaeImageProcessor(ConfigMixin):
         image: torch.FloatTensor,
         output_type: str = "pil",
         do_denormalize: Optional[List[bool]] = None,
-    ):
+    ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
+        """
+        Postprocess the image output from tensor to `output_type`.
+
+        Args:
+            image (`torch.FloatTensor`):
+                The image input, should be a pytorch tensor with shape `B x C x H x W`.
+            output_type (`str`, *optional*, defaults to `pil`):
+                The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
+            do_denormalize (`List[bool]`, *optional*, defaults to `None`):
+                Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
+                `VaeImageProcessor` config.
+
+        Returns:
+            `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
+                The postprocessed image.
+        """
         if not isinstance(image, torch.Tensor):
             raise ValueError(
                 f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
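A minimal usage sketch for the method above, assuming the public `VaeImageProcessor` API from diffusers (shapes and values are made up):

import torch
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)
decoded = torch.rand(1, 3, 64, 64) * 2 - 1       # stand-in for a VAE decoder output in [-1, 1]
images = processor.postprocess(decoded, output_type="pil")
print(type(images[0]))                           # PIL.Image.Image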
@@ -390,7 +426,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
         super().__init__()
 
     @staticmethod
-    def numpy_to_pil(images):
+    def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
         """
         Convert a NumPy image or a batch of images to a PIL image.
         """
@@ -406,7 +442,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
         return pil_images
 
     @staticmethod
-    def rgblike_to_depthmap(image):
+    def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
         """
         Args:
             image: RGB-like depth image
@@ -416,7 +452,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
         """
         return image[:, :, 1] * 2**8 + image[:, :, 2]
 
-    def numpy_to_depth(self, images):
+    def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
         """
         Convert a NumPy depth image or a batch of images to a PIL image.
         """
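For illustration (not part of this commit): `rgblike_to_depthmap` reassembles a 16-bit depth value from two 8-bit channels (channel 1 holds the high byte, channel 2 the low byte); standalone:

import numpy as np

rgb_like = np.zeros((2, 2, 3), dtype=np.uint16)
rgb_like[..., 1] = 0x12                               # high byte
rgb_like[..., 2] = 0x34                               # low byte
depth = rgb_like[:, :, 1] * 2**8 + rgb_like[:, :, 2]
print(hex(int(depth[0, 0])))                          # 0x1234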
@@ -441,7 +477,23 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
         image: torch.FloatTensor,
         output_type: str = "pil",
         do_denormalize: Optional[List[bool]] = None,
-    ):
+    ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
+        """
+        Postprocess the image output from tensor to `output_type`.
+
+        Args:
+            image (`torch.FloatTensor`):
+                The image input, should be a pytorch tensor with shape `B x C x H x W`.
+            output_type (`str`, *optional*, defaults to `pil`):
+                The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
+            do_denormalize (`List[bool]`, *optional*, defaults to `None`):
+                Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
+                `VaeImageProcessor` config.
+
+        Returns:
+            `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
+                The postprocessed image.
+        """
         if not isinstance(image, torch.Tensor):
             raise ValueError(
                 f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
@@ -65,11 +65,11 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
         self,
         in_channels: int = 3,
         out_channels: int = 3,
-        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
-        down_block_out_channels: Tuple[int] = (64,),
+        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
+        down_block_out_channels: Tuple[int, ...] = (64,),
         layers_per_down_block: int = 1,
-        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
-        up_block_out_channels: Tuple[int] = (64,),
+        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
+        up_block_out_channels: Tuple[int, ...] = (64,),
         layers_per_up_block: int = 1,
        act_fn: str = "silu",
         latent_channels: int = 4,
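For illustration (not part of this commit): `Tuple[int]` means a tuple of exactly one int, so multi-element defaults like `(64, 64, 64, 64)` don't match it; `Tuple[int, ...]` is the variable-length spelling, which is why this same change recurs in the hunks below:

from typing import Tuple

def f(channels: Tuple[int]) -> None: ...       # type checkers accept only 1-tuples like (64,)
def g(channels: Tuple[int, ...]) -> None: ...  # accepts (64,), (64, 128), (320, 640, 1280, 1280), ...

g((320, 640, 1280, 1280))  # fine; the same call to f() would be flagged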
@@ -109,7 +109,9 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
         self.use_tiling = False
 
     @apply_forward_hook
-    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+    def encode(
+        self, x: torch.FloatTensor, return_dict: bool = True
+    ) -> Union[AutoencoderKLOutput, Tuple[torch.FloatTensor]]:
         h = self.encoder(x)
         moments = self.quant_conv(h)
         posterior = DiagonalGaussianDistribution(moments)
@@ -125,7 +127,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
         image: Optional[torch.FloatTensor] = None,
         mask: Optional[torch.FloatTensor] = None,
         return_dict: bool = True,
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
+    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
         z = self.post_quant_conv(z)
         dec = self.decoder(z, image, mask)
 
@@ -142,7 +144,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
         image: Optional[torch.FloatTensor] = None,
         mask: Optional[torch.FloatTensor] = None,
         return_dict: bool = True,
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
+    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
         decoded = self._decode(z, image, mask).sample
 
         if not return_dict:
@@ -157,7 +159,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
         sample_posterior: bool = False,
         return_dict: bool = True,
         generator: Optional[torch.Generator] = None,
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
+    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
         r"""
         Args:
             sample (`torch.FloatTensor`): Input sample.
@@ -322,13 +322,13 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
 
         return DecoderOutput(sample=decoded)
 
-    def blend_v(self, a, b, blend_extent):
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
         blend_extent = min(a.shape[2], b.shape[2], blend_extent)
         for y in range(blend_extent):
             b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
         return b
 
-    def blend_h(self, a, b, blend_extent):
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
         blend_extent = min(a.shape[3], b.shape[3], blend_extent)
         for x in range(blend_extent):
             b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
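For illustration (not part of this commit): `blend_v`/`blend_h` cross-fade the seam between two decoded tiles, ramping one out while ramping the other in across the overlap; the same weighting in 1-D:

import torch

a = torch.ones(8)    # edge of the previous tile
b = torch.zeros(8)   # edge of the next tile
blend_extent = 4
for x in range(blend_extent):
    b[x] = a[-blend_extent + x] * (1 - x / blend_extent) + b[x] * (x / blend_extent)
print(b)  # tensor([1.0000, 0.7500, 0.5000, 0.2500, 0., 0., 0., 0.])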
@@ -96,18 +96,18 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        in_channels=3,
-        out_channels=3,
-        encoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
-        decoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
+        in_channels: int = 3,
+        out_channels: int = 3,
+        encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
+        decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
         act_fn: str = "relu",
         latent_channels: int = 4,
         upsampling_scaling_factor: int = 2,
-        num_encoder_blocks: Tuple[int] = (1, 3, 3, 3),
-        num_decoder_blocks: Tuple[int] = (3, 3, 3, 1),
+        num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
+        num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1),
         latent_magnitude: int = 3,
         latent_shift: float = 0.5,
-        force_upcast: float = False,
+        force_upcast: bool = False,
         scaling_factor: float = 1.0,
     ):
         super().__init__()
@@ -147,33 +147,33 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         self.tile_sample_min_size = 512
         self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
 
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
         if isinstance(module, (EncoderTiny, DecoderTiny)):
             module.gradient_checkpointing = value
 
-    def scale_latents(self, x):
+    def scale_latents(self, x: torch.FloatTensor) -> torch.FloatTensor:
         """raw latents -> [0, 1]"""
         return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)
 
-    def unscale_latents(self, x):
+    def unscale_latents(self, x: torch.FloatTensor) -> torch.FloatTensor:
         """[0, 1] -> raw latents"""
         return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
 
-    def enable_slicing(self):
+    def enable_slicing(self) -> None:
         r"""
         Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
         compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
         """
         self.use_slicing = True
 
-    def disable_slicing(self):
+    def disable_slicing(self) -> None:
         r"""
         Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
         decoding in one step.
         """
         self.use_slicing = False
 
-    def enable_tiling(self, use_tiling: bool = True):
+    def enable_tiling(self, use_tiling: bool = True) -> None:
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
         compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
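For illustration (not part of this commit): `scale_latents`/`unscale_latents` map raw latents into [0, 1] and back using `latent_magnitude` and `latent_shift`; standalone with the defaults (3 and 0.5):

import torch

latent_magnitude, latent_shift = 3, 0.5
x = torch.tensor([-3.0, 0.0, 3.0])
scaled = x.div(2 * latent_magnitude).add(latent_shift).clamp(0, 1)  # tensor([0.0, 0.5, 1.0])
restored = scaled.sub(latent_shift).mul(2 * latent_magnitude)       # tensor([-3.0, 0.0, 3.0])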
@@ -181,7 +181,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         """
         self.use_tiling = use_tiling
 
-    def disable_tiling(self):
+    def disable_tiling(self) -> None:
         r"""
         Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
         decoding in one step.
@@ -197,13 +197,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
 
         Args:
             x (`torch.FloatTensor`): Input batch of images.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.
 
         Returns:
-            [`~models.autoencoder_tiny.AutoencoderTinyOutput`] or `tuple`:
-                If return_dict is True, a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] is returned, otherwise a
-                plain `tuple` is returned.
+            `torch.FloatTensor`: Encoded batch of images.
         """
         # scale of encoder output relative to input
         sf = self.spatial_scale_factor
@@ -249,13 +245,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
 
         Args:
             x (`torch.FloatTensor`): Input batch of images.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.
 
         Returns:
-            [`~models.vae.DecoderOutput`] or `tuple`:
-                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
-                returned.
+            `torch.FloatTensor`: Encoded batch of images.
         """
         # scale of decoder output relative to input
         sf = self.spatial_scale_factor
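A minimal usage sketch for `AutoencoderTiny`, assuming the public diffusers API and the community `madebyollin/taesd` checkpoint (treat both as assumptions of this example):

import torch
from diffusers import AutoencoderTiny

vae = AutoencoderTiny.from_pretrained("madebyollin/taesd")
image = torch.rand(1, 3, 512, 512) * 2 - 1   # stand-in input in [-1, 1]
latents = vae.encode(image).latents
reconstruction = vae.decode(latents).sample
print(reconstruction.shape)                   # torch.Size([1, 3, 512, 512])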
@@ -70,39 +70,39 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        scaling_factor=0.18215,
-        latent_channels=4,
-        encoder_act_fn="silu",
-        encoder_block_out_channels=(128, 256, 512, 512),
-        encoder_double_z=True,
-        encoder_down_block_types=(
+        scaling_factor: float = 0.18215,
+        latent_channels: int = 4,
+        encoder_act_fn: str = "silu",
+        encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
+        encoder_double_z: bool = True,
+        encoder_down_block_types: Tuple[str, ...] = (
             "DownEncoderBlock2D",
             "DownEncoderBlock2D",
             "DownEncoderBlock2D",
             "DownEncoderBlock2D",
         ),
-        encoder_in_channels=3,
-        encoder_layers_per_block=2,
-        encoder_norm_num_groups=32,
-        encoder_out_channels=4,
-        decoder_add_attention=False,
-        decoder_block_out_channels=(320, 640, 1024, 1024),
-        decoder_down_block_types=(
+        encoder_in_channels: int = 3,
+        encoder_layers_per_block: int = 2,
+        encoder_norm_num_groups: int = 32,
+        encoder_out_channels: int = 4,
+        decoder_add_attention: bool = False,
+        decoder_block_out_channels: Tuple[int, ...] = (320, 640, 1024, 1024),
+        decoder_down_block_types: Tuple[str, ...] = (
             "ResnetDownsampleBlock2D",
             "ResnetDownsampleBlock2D",
             "ResnetDownsampleBlock2D",
             "ResnetDownsampleBlock2D",
         ),
-        decoder_downsample_padding=1,
-        decoder_in_channels=7,
-        decoder_layers_per_block=3,
-        decoder_norm_eps=1e-05,
-        decoder_norm_num_groups=32,
-        decoder_num_train_timesteps=1024,
-        decoder_out_channels=6,
-        decoder_resnet_time_scale_shift="scale_shift",
-        decoder_time_embedding_type="learned",
-        decoder_up_block_types=(
+        decoder_downsample_padding: int = 1,
+        decoder_in_channels: int = 7,
+        decoder_layers_per_block: int = 3,
+        decoder_norm_eps: float = 1e-05,
+        decoder_norm_num_groups: int = 32,
+        decoder_num_train_timesteps: int = 1024,
+        decoder_out_channels: int = 6,
+        decoder_resnet_time_scale_shift: str = "scale_shift",
+        decoder_time_embedding_type: str = "learned",
+        decoder_up_block_types: Tuple[str, ...] = (
             "ResnetUpsampleBlock2D",
             "ResnetUpsampleBlock2D",
             "ResnetUpsampleBlock2D",
@@ -304,8 +304,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         z: torch.FloatTensor,
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
-        num_inference_steps=2,
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
+        num_inference_steps: int = 2,
+    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
         z = (z * self.config.scaling_factor - self.means) / self.stds
 
         scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
@@ -333,14 +333,14 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         return DecoderOutput(sample=x_0)
 
     # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_v
-    def blend_v(self, a, b, blend_extent):
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
         blend_extent = min(a.shape[2], b.shape[2], blend_extent)
         for y in range(blend_extent):
             b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
         return b
 
     # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_h
-    def blend_h(self, a, b, blend_extent):
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
         blend_extent = min(a.shape[3], b.shape[3], blend_extent)
         for x in range(blend_extent):
             b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
@@ -407,7 +407,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         sample_posterior: bool = False,
         return_dict: bool = True,
         generator: Optional[torch.Generator] = None,
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
+    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
         r"""
         Args:
             sample (`torch.FloatTensor`): Input sample.
@@ -415,6 +415,12 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
                 Whether to sample from the posterior.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+            generator (`torch.Generator`, *optional*, defaults to `None`):
+                Generator to use for sampling.
+
+        Returns:
+            [`DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`DecoderOutput`] is returned, otherwise a plain `tuple` is returned.
         """
         x = sample
         posterior = self.encode(x).latent_dist
@@ -76,7 +76,7 @@ class ControlNetConditioningEmbedding(nn.Module):
         self,
         conditioning_embedding_channels: int,
         conditioning_channels: int = 3,
-        block_out_channels: Tuple[int] = (16, 32, 96, 256),
+        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
     ):
         super().__init__()
 
@@ -171,6 +171,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
             The tuple of output channel for each block in the `conditioning_embedding` layer.
         global_pool_conditions (`bool`, defaults to `False`):
+            TODO(Patrick) - unused parameter.
+        addition_embed_type_num_heads (`int`, defaults to 64):
+            The number of heads to use for the `TextTimeEmbedding` layer.
     """
 
     _supports_gradient_checkpointing = True
@@ -182,14 +185,14 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         conditioning_channels: int = 3,
         flip_sin_to_cos: bool = True,
         freq_shift: int = 0,
-        down_block_types: Tuple[str] = (
+        down_block_types: Tuple[str, ...] = (
             "CrossAttnDownBlock2D",
             "CrossAttnDownBlock2D",
             "CrossAttnDownBlock2D",
             "DownBlock2D",
         ),
         only_cross_attention: Union[bool, Tuple[bool]] = False,
-        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
         layers_per_block: int = 2,
         downsample_padding: int = 1,
         mid_block_scale_factor: float = 1,
@@ -197,11 +200,11 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         norm_num_groups: Optional[int] = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
-        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+        transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
         encoder_hid_dim: Optional[int] = None,
         encoder_hid_dim_type: Optional[str] = None,
-        attention_head_dim: Union[int, Tuple[int]] = 8,
-        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+        attention_head_dim: Union[int, Tuple[int, ...]] = 8,
+        num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
         use_linear_projection: bool = False,
         class_embed_type: Optional[str] = None,
         addition_embed_type: Optional[str] = None,
@@ -211,9 +214,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         resnet_time_scale_shift: str = "default",
         projection_class_embeddings_input_dim: Optional[int] = None,
         controlnet_conditioning_channel_order: str = "rgb",
-        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
+        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
         global_pool_conditions: bool = False,
-        addition_embed_type_num_heads=64,
+        addition_embed_type_num_heads: int = 64,
     ):
         super().__init__()
 
@@ -426,7 +429,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         cls,
         unet: UNet2DConditionModel,
         controlnet_conditioning_channel_order: str = "rgb",
-        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
+        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
         load_weights_from_unet: bool = True,
     ):
         r"""
@@ -570,7 +573,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         self.set_attn_processor(processor, _remove_lora=True)
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
-    def set_attention_slice(self, slice_size):
+    def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
         r"""
         Enable sliced attention computation.
 
@@ -635,7 +638,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)
 
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
         if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
             module.gradient_checkpointing = value
 
@@ -653,7 +656,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         guess_mode: bool = False,
         return_dict: bool = True,
-    ) -> Union[ControlNetOutput, Tuple]:
+    ) -> Union[ControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
         """
         The [`ControlNetModel`] forward method.
@@ -46,10 +46,10 @@ class FlaxControlNetOutput(BaseOutput):
 
 class FlaxControlNetConditioningEmbedding(nn.Module):
     conditioning_embedding_channels: int
-    block_out_channels: Tuple[int] = (16, 32, 96, 256)
+    block_out_channels: Tuple[int, ...] = (16, 32, 96, 256)
     dtype: jnp.dtype = jnp.float32
 
-    def setup(self):
+    def setup(self) -> None:
         self.conv_in = nn.Conv(
             self.block_out_channels[0],
             kernel_size=(3, 3),
@@ -87,7 +87,7 @@ class FlaxControlNetConditioningEmbedding(nn.Module):
             dtype=self.dtype,
         )
 
-    def __call__(self, conditioning):
+    def __call__(self, conditioning: jnp.ndarray) -> jnp.ndarray:
         embedding = self.conv_in(conditioning)
         embedding = nn.silu(embedding)
 
@@ -148,17 +148,17 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
     """
     sample_size: int = 32
     in_channels: int = 4
-    down_block_types: Tuple[str] = (
+    down_block_types: Tuple[str, ...] = (
         "CrossAttnDownBlock2D",
         "CrossAttnDownBlock2D",
         "CrossAttnDownBlock2D",
         "DownBlock2D",
     )
-    only_cross_attention: Union[bool, Tuple[bool]] = False
-    block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
+    only_cross_attention: Union[bool, Tuple[bool, ...]] = False
+    block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
     layers_per_block: int = 2
-    attention_head_dim: Union[int, Tuple[int]] = 8
-    num_attention_heads: Optional[Union[int, Tuple[int]]] = None
+    attention_head_dim: Union[int, Tuple[int, ...]] = 8
+    num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
     cross_attention_dim: int = 1280
     dropout: float = 0.0
     use_linear_projection: bool = False
@@ -166,7 +166,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
     flip_sin_to_cos: bool = True
     freq_shift: int = 0
     controlnet_conditioning_channel_order: str = "rgb"
-    conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)
+    conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256)
 
     def init_weights(self, rng: jax.Array) -> FrozenDict:
         # init input tensors
@@ -182,7 +182,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
 
         return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
 
-    def setup(self):
+    def setup(self) -> None:
         block_out_channels = self.block_out_channels
         time_embed_dim = block_out_channels[0] * 4
 
@@ -312,21 +312,21 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
 
     def __call__(
         self,
-        sample,
-        timesteps,
-        encoder_hidden_states,
-        controlnet_cond,
+        sample: jnp.ndarray,
+        timesteps: Union[jnp.ndarray, float, int],
+        encoder_hidden_states: jnp.ndarray,
+        controlnet_cond: jnp.ndarray,
         conditioning_scale: float = 1.0,
         return_dict: bool = True,
        train: bool = False,
-    ) -> Union[FlaxControlNetOutput, Tuple]:
+    ) -> Union[FlaxControlNetOutput, Tuple[Tuple[jnp.ndarray, ...], jnp.ndarray]]:
         r"""
         Args:
             sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
             timestep (`jnp.ndarray` or `float` or `int`): timesteps
             encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
             controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
-            conditioning_scale: (`float`) the scale factor for controlnet outputs
+            conditioning_scale (`float`, *optional*, defaults to `1.0`): the scale factor for controlnet outputs
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
                 plain tuple.
@@ -335,8 +335,8 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
 
         Returns:
             [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
-                [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
-                When returning a tuple, the first element is the sample tensor.
+                [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
+                `tuple`. When returning a tuple, the first element is the sample tensor.
         """
         channel_order = self.controlnet_conditioning_channel_order
         if channel_order == "bgr":
@@ -18,13 +18,14 @@ import inspect
 import itertools
 import os
 import re
+from collections import OrderedDict
 from functools import partial
 from typing import Any, Callable, List, Optional, Tuple, Union
 
 import safetensors
 import torch
 from huggingface_hub import create_repo
-from torch import Tensor, device, nn
+from torch import Tensor, nn
 
 from .. import __version__
 from ..utils import (
@@ -61,7 +62,7 @@ if is_accelerate_available():
     from accelerate.utils.versions import is_torch_version
 
 
-def get_parameter_device(parameter: torch.nn.Module):
+def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
     try:
         parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
         return next(parameters_and_buffers).device
@@ -77,7 +78,7 @@ def get_parameter_device(parameter: torch.nn.Module):
         return first_tuple[1].device
 
 
-def get_parameter_dtype(parameter: torch.nn.Module):
+def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
     try:
         params = tuple(parameter.parameters())
         if len(params) > 0:
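For illustration (not part of this commit): the device/dtype lookup these helpers perform, standalone:

import itertools
import torch

model = torch.nn.Linear(4, 4)
first = next(itertools.chain(model.parameters(), model.buffers()))
print(first.device, first.dtype)  # e.g. cpu torch.float32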
@@ -130,7 +131,13 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
     )
 
 
-def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_name_or_path=None):
+def load_model_dict_into_meta(
+    model,
+    state_dict: OrderedDict,
+    device: Optional[Union[str, torch.device]] = None,
+    dtype: Optional[Union[str, torch.dtype]] = None,
+    model_name_or_path: Optional[str] = None,
+) -> List[str]:
     device = device or torch.device("cpu")
     dtype = dtype or torch.float32
 
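For context (a sketch of the meta-device loading pattern this helper supports, written with plain PyTorch rather than the diffusers helper; `assign=True` is an assumption requiring torch >= 2.1):

import torch

# Build the module skeleton on the meta device (no storage allocated), then
# materialize real tensors from a state dict onto it.
with torch.device("meta"):
    model = torch.nn.Linear(1024, 1024)
state_dict = {"weight": torch.randn(1024, 1024), "bias": torch.randn(1024)}
model.load_state_dict(state_dict, assign=True)  # assign=True replaces the meta tensors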
@@ -156,7 +163,7 @@ def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_
     return unexpected_keys
 
 
-def _load_state_dict_into_model(model_to_load, state_dict):
+def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
     # Convert old format to new format if needed from a PyTorch state_dict
     # copy state_dict so _load_from_state_dict can modify it
     state_dict = state_dict.copy()
@@ -164,7 +171,7 @@ def _load_state_dict_into_model(model_to_load, state_dict):
 
     # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
     # so we need to apply the function recursively.
-    def load(module: torch.nn.Module, prefix=""):
+    def load(module: torch.nn.Module, prefix: str = ""):
         args = (state_dict, prefix, {}, True, [], [], error_msgs)
         module._load_from_state_dict(*args)
 
@@ -220,7 +227,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         """
         return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
 
-    def enable_gradient_checkpointing(self):
+    def enable_gradient_checkpointing(self) -> None:
         """
         Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
         *checkpoint activations* in other frameworks).
@@ -229,7 +236,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
         self.apply(partial(self._set_gradient_checkpointing, value=True))
 
-    def disable_gradient_checkpointing(self):
+    def disable_gradient_checkpointing(self) -> None:
         """
         Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
         *checkpoint activations* in other frameworks).
@@ -254,7 +261,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             if isinstance(module, torch.nn.Module):
                 fn_recursive_set_mem_eff(module)
 
-    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
+    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None:
         r"""
         Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
 
@@ -290,7 +297,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         """
         self.set_use_memory_efficient_attention_xformers(True, attention_op)
 
-    def disable_xformers_memory_efficient_attention(self):
+    def disable_xformers_memory_efficient_attention(self) -> None:
         r"""
         Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
         """
@@ -447,7 +454,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         self,
         save_directory: Union[str, os.PathLike],
         is_main_process: bool = True,
-        save_function: Callable = None,
+        save_function: Optional[Callable] = None,
         safe_serialization: bool = True,
         variant: Optional[str] = None,
         push_to_hub: bool = False,
@@ -910,10 +917,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     def _load_pretrained_model(
         cls,
         model,
-        state_dict,
+        state_dict: OrderedDict,
         resolved_archive_file,
-        pretrained_model_name_or_path,
-        ignore_mismatched_sizes=False,
+        pretrained_model_name_or_path: Union[str, os.PathLike],
+        ignore_mismatched_sizes: bool = False,
     ):
         # Retrieve missing & unexpected_keys
         model_state_dict = model.state_dict()
@@ -1011,7 +1018,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
 
     @property
-    def device(self) -> device:
+    def device(self) -> torch.device:
         """
         `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
         device).
@@ -1063,7 +1070,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         else:
             return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
 
-    def _convert_deprecated_attention_blocks(self, state_dict):
+    def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
         deprecated_attention_block_paths = []
 
         def recursive_find_attn_block(name, module):
@@ -1107,7 +1114,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             if f"{path}.proj_attn.bias" in state_dict:
                 state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
 
-    def _temp_convert_self_to_deprecated_attention_blocks(self):
+    def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
         deprecated_attention_block_modules = []
 
         def recursive_find_attn_block(module):
@@ -1134,10 +1141,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             del module.to_v
             del module.to_out
 
-    def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
+    def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None:
         deprecated_attention_block_modules = []
 
-        def recursive_find_attn_block(module):
+        def recursive_find_attn_block(module) -> None:
             if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
                 deprecated_attention_block_modules.append(module)
@@ -101,8 +101,8 @@ class AdaLayerNormSingle(nn.Module):
     def forward(
         self,
         timestep: torch.Tensor,
-        added_cond_kwargs: Dict[str, torch.Tensor] = None,
-        batch_size: int = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+        batch_size: Optional[int] = None,
         hidden_dtype: Optional[torch.dtype] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         # No modulation happening here.
@@ -164,7 +164,9 @@ class Upsample2D(nn.Module):
         else:
             self.Conv2d_0 = conv
 
-    def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, scale: float = 1.0):
+    def forward(
+        self, hidden_states: torch.FloatTensor, output_size: Optional[int] = None, scale: float = 1.0
+    ) -> torch.FloatTensor:
         assert hidden_states.shape[1] == self.channels
 
         if self.use_conv_transpose:
@@ -256,7 +258,7 @@ class Downsample2D(nn.Module):
         else:
             self.conv = conv
 
-    def forward(self, hidden_states, scale: float = 1.0):
+    def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
         assert hidden_states.shape[1] == self.channels
 
         if self.use_conv and self.padding == 0:
@@ -280,7 +282,7 @@ class FirUpsample2D(nn.Module):
     """A 2D FIR upsampling layer with an optional convolution.
 
     Parameters:
-        channels (`int`):
+        channels (`int`, optional):
            number of channels in the inputs and outputs.
         use_conv (`bool`, default `False`):
             option to use a convolution.
@@ -292,7 +294,7 @@ class FirUpsample2D(nn.Module):
 
     def __init__(
         self,
-        channels: int = None,
+        channels: Optional[int] = None,
        out_channels: Optional[int] = None,
         use_conv: bool = False,
         fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
@@ -307,12 +309,12 @@ class FirUpsample2D(nn.Module):
 
     def _upsample_2d(
         self,
-        hidden_states: torch.Tensor,
-        weight: Optional[torch.Tensor] = None,
+        hidden_states: torch.FloatTensor,
+        weight: Optional[torch.FloatTensor] = None,
         kernel: Optional[torch.FloatTensor] = None,
         factor: int = 2,
         gain: float = 1,
-    ) -> torch.Tensor:
+    ) -> torch.FloatTensor:
         """Fused `upsample_2d()` followed by `Conv2d()`.
 
         Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
@@ -320,17 +322,21 @@ class FirUpsample2D(nn.Module):
         arbitrary order.
 
         Args:
-            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-            weight: Weight tensor of the shape `[filterH, filterW, inChannels,
-                outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
-            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
-                (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
-            factor: Integer upsampling factor (default: 2).
-            gain: Scaling factor for signal magnitude (default: 1.0).
+            hidden_states (`torch.FloatTensor`):
+                Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+            weight (`torch.FloatTensor`, *optional*):
+                Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
+                performed by `inChannels = x.shape[0] // numGroups`.
+            kernel (`torch.FloatTensor`, *optional*):
+                FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+                corresponds to nearest-neighbor upsampling.
+            factor (`int`, *optional*): Integer upsampling factor (default: 2).
+            gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0).
 
         Returns:
-            output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
-                datatype as `hidden_states`.
+            output (`torch.FloatTensor`):
+                Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
+                datatype as `hidden_states`.
         """
 
         assert isinstance(factor, int) and factor >= 1
@@ -392,7 +398,7 @@ class FirUpsample2D(nn.Module):
 
         return output
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
         if self.use_conv:
             height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
             height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
@@ -418,7 +424,7 @@ class FirDownsample2D(nn.Module):
 
     def __init__(
         self,
-        channels: int = None,
+        channels: Optional[int] = None,
         out_channels: Optional[int] = None,
         use_conv: bool = False,
         fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
@@ -433,30 +439,35 @@ class FirDownsample2D(nn.Module):
 
     def _downsample_2d(
         self,
-        hidden_states: torch.Tensor,
-        weight: Optional[torch.Tensor] = None,
+        hidden_states: torch.FloatTensor,
+        weight: Optional[torch.FloatTensor] = None,
         kernel: Optional[torch.FloatTensor] = None,
         factor: int = 2,
         gain: float = 1,
-    ) -> torch.Tensor:
+    ) -> torch.FloatTensor:
         """Fused `Conv2d()` followed by `downsample_2d()`.
         Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
         efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
         arbitrary order.
 
         Args:
-            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-            weight:
+            hidden_states (`torch.FloatTensor`):
+                Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+            weight (`torch.FloatTensor`, *optional*):
                 Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
                 performed by `inChannels = x.shape[0] // numGroups`.
-            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
-                factor`, which corresponds to average pooling.
-            factor: Integer downsampling factor (default: 2).
-            gain: Scaling factor for signal magnitude (default: 1.0).
+            kernel (`torch.FloatTensor`, *optional*):
+                FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+                corresponds to average pooling.
+            factor (`int`, *optional*, default to `2`):
+                Integer downsampling factor.
+            gain (`float`, *optional*, default to `1.0`):
+                Scaling factor for signal magnitude.
 
         Returns:
-            output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
-                same datatype as `x`.
+            output (`torch.FloatTensor`):
+                Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
+                datatype as `x`.
         """
 
         assert isinstance(factor, int) and factor >= 1
@@ -492,7 +503,7 @@ class FirDownsample2D(nn.Module):
 
         return output
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
         if self.use_conv:
             downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
             hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
@@ -682,7 +693,9 @@ class ResnetBlock2D(nn.Module):
                 in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
             )
 
-    def forward(self, input_tensor, temb, scale: float = 1.0):
+    def forward(
+        self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, scale: float = 1.0
+    ) -> torch.FloatTensor:
         hidden_states = input_tensor
 
         if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
@@ -778,7 +791,7 @@ class Conv1dBlock(nn.Module):
         out_channels (`int`): Number of output channels.
         kernel_size (`int` or `tuple`): Size of the convolving kernel.
         n_groups (`int`, default `8`): Number of groups to separate the channels into.
-        activation (`str`, defaults `mish`): Name of the activation function.
+        activation (`str`, defaults to `mish`): Name of the activation function.
     """
 
     def __init__(
@@ -853,8 +866,8 @@ class ResidualTemporalBlock1D(nn.Module):
 
 
 def upsample_2d(
-    hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
-) -> torch.Tensor:
+    hidden_states: torch.FloatTensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
+) -> torch.FloatTensor:
     r"""Upsample2D a batch of 2D images with the given filter.
     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
     filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
@@ -862,14 +875,19 @@ def upsample_2d(
     a: multiple of the upsampling factor.
 
     Args:
-        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
-            (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
-        factor: Integer upsampling factor (default: 2).
-        gain: Scaling factor for signal magnitude (default: 1.0).
+        hidden_states (`torch.FloatTensor`):
+            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+        kernel (`torch.FloatTensor`, *optional*):
+            FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+            corresponds to nearest-neighbor upsampling.
+        factor (`int`, *optional*, default to `2`):
+            Integer upsampling factor.
+        gain (`float`, *optional*, default to `1.0`):
+            Scaling factor for signal magnitude (default: 1.0).
 
     Returns:
-        output: Tensor of the shape `[N, C, H * factor, W * factor]`
+        output (`torch.FloatTensor`):
+            Tensor of the shape `[N, C, H * factor, W * factor]`
     """
     assert isinstance(factor, int) and factor >= 1
     if kernel is None:
@@ -892,8 +910,8 @@ def upsample_2d(
 
 
 def downsample_2d(
-    hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
-) -> torch.Tensor:
+    hidden_states: torch.FloatTensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
+) -> torch.FloatTensor:
     r"""Downsample2D a batch of 2D images with the given filter.
     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
     given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
@@ -901,14 +919,19 @@ def downsample_2d(
     shape is a multiple of the downsampling factor.
 
    Args:
-        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
-            (separable). The default is `[1] * factor`, which corresponds to average pooling.
-        factor: Integer downsampling factor (default: 2).
-        gain: Scaling factor for signal magnitude (default: 1.0).
+        hidden_states (`torch.FloatTensor`)
+            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+        kernel (`torch.FloatTensor`, *optional*):
+            FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+            corresponds to average pooling.
+        factor (`int`, *optional*, default to `2`):
+            Integer downsampling factor.
+        gain (`float`, *optional*, default to `1.0`):
+            Scaling factor for signal magnitude.
 
     Returns:
-        output: Tensor of the shape `[N, C, H // factor, W // factor]`
+        output (`torch.FloatTensor`):
+            Tensor of the shape `[N, C, H // factor, W // factor]`
     """
 
     assert isinstance(factor, int) and factor >= 1
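For illustration (not part of this commit): as the docstrings above note, the default `kernel=[1] * factor` makes `upsample_2d` equivalent to nearest-neighbor upsampling and `downsample_2d` to average pooling; a plain-PyTorch sketch of those default cases:

import torch
import torch.nn.functional as F

x = torch.arange(4.0).reshape(1, 1, 2, 2)
up = F.interpolate(x, scale_factor=2, mode="nearest")  # default upsample_2d behavior
down = F.avg_pool2d(up, kernel_size=2)                 # default downsample_2d behavior
assert torch.equal(down, x)                            # the two defaults invert each other here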
@@ -100,18 +100,18 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     sample_size: int = 32
     in_channels: int = 4
     out_channels: int = 4
-    down_block_types: Tuple[str] = (
+    down_block_types: Tuple[str, ...] = (
         "CrossAttnDownBlock2D",
         "CrossAttnDownBlock2D",
         "CrossAttnDownBlock2D",
         "DownBlock2D",
     )
-    up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
+    up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
     only_cross_attention: Union[bool, Tuple[bool]] = False
-    block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
+    block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
     layers_per_block: int = 2
-    attention_head_dim: Union[int, Tuple[int]] = 8
-    num_attention_heads: Optional[Union[int, Tuple[int]]] = None
+    attention_head_dim: Union[int, Tuple[int, ...]] = 8
+    num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
     cross_attention_dim: int = 1280
     dropout: float = 0.0
     use_linear_projection: bool = False
@@ -120,7 +120,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     freq_shift: int = 0
     use_memory_efficient_attention: bool = False
     split_head_dim: bool = False
-    transformer_layers_per_block: Union[int, Tuple[int]] = 1
+    transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1
     addition_embed_type: Optional[str] = None
     addition_time_embed_dim: Optional[int] = None
     addition_embed_type_num_heads: int = 64
@@ -158,7 +158,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         }
         return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]
 
-    def setup(self):
+    def setup(self) -> None:
         block_out_channels = self.block_out_channels
         time_embed_dim = block_out_channels[0] * 4
 
@@ -320,15 +320,15 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
 
     def __call__(
         self,
-        sample,
-        timesteps,
-        encoder_hidden_states,
+        sample: jnp.ndarray,
+        timesteps: Union[jnp.ndarray, float, int],
+        encoder_hidden_states: jnp.ndarray,
         added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
-        down_block_additional_residuals=None,
-        mid_block_additional_residual=None,
+        down_block_additional_residuals: Optional[Tuple[jnp.ndarray, ...]] = None,
+        mid_block_additional_residual: Optional[jnp.ndarray] = None,
         return_dict: bool = True,
         train: bool = False,
-    ) -> Union[FlaxUNet2DConditionOutput, Tuple]:
+    ) -> Union[FlaxUNet2DConditionOutput, Tuple[jnp.ndarray]]:
         r"""
         Args:
             sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -26,26 +26,26 @@ from .transformer_temporal import TransformerTemporalModel
 
 
 def get_down_block(
-    down_block_type,
-    num_layers,
-    in_channels,
-    out_channels,
-    temb_channels,
-    add_downsample,
-    resnet_eps,
-    resnet_act_fn,
-    num_attention_heads,
-    resnet_groups=None,
-    cross_attention_dim=None,
-    downsample_padding=None,
-    dual_cross_attention=False,
-    use_linear_projection=True,
-    only_cross_attention=False,
-    upcast_attention=False,
-    resnet_time_scale_shift="default",
-    temporal_num_attention_heads=8,
-    temporal_max_seq_length=32,
-):
+    down_block_type: str,
+    num_layers: int,
+    in_channels: int,
+    out_channels: int,
+    temb_channels: int,
+    add_downsample: bool,
+    resnet_eps: float,
+    resnet_act_fn: str,
+    num_attention_heads: int,
+    resnet_groups: Optional[int] = None,
+    cross_attention_dim: Optional[int] = None,
+    downsample_padding: Optional[int] = None,
+    dual_cross_attention: bool = False,
+    use_linear_projection: bool = True,
+    only_cross_attention: bool = False,
+    upcast_attention: bool = False,
+    resnet_time_scale_shift: str = "default",
+    temporal_num_attention_heads: int = 8,
+    temporal_max_seq_length: int = 32,
+) -> Union["DownBlock3D", "CrossAttnDownBlock3D", "DownBlockMotion", "CrossAttnDownBlockMotion"]:
     if down_block_type == "DownBlock3D":
         return DownBlock3D(
             num_layers=num_layers,
@@ -123,28 +123,28 @@ def get_down_block(
 
 
 def get_up_block(
-    up_block_type,
-    num_layers,
-    in_channels,
-    out_channels,
-    prev_output_channel,
-    temb_channels,
-    add_upsample,
-    resnet_eps,
-    resnet_act_fn,
-    num_attention_heads,
-    resolution_idx=None,
-    resnet_groups=None,
-    cross_attention_dim=None,
-    dual_cross_attention=False,
-    use_linear_projection=True,
-    only_cross_attention=False,
-    upcast_attention=False,
-    resnet_time_scale_shift="default",
-    temporal_num_attention_heads=8,
-    temporal_cross_attention_dim=None,
-    temporal_max_seq_length=32,
-):
+    up_block_type: str,
+    num_layers: int,
+    in_channels: int,
+    out_channels: int,
+    prev_output_channel: int,
+    temb_channels: int,
+    add_upsample: bool,
+    resnet_eps: float,
+    resnet_act_fn: str,
+    num_attention_heads: int,
+    resolution_idx: Optional[int] = None,
+    resnet_groups: Optional[int] = None,
+    cross_attention_dim: Optional[int] = None,
+    dual_cross_attention: bool = False,
+    use_linear_projection: bool = True,
+    only_cross_attention: bool = False,
+    upcast_attention: bool = False,
+    resnet_time_scale_shift: str = "default",
+    temporal_num_attention_heads: int = 8,
+    temporal_cross_attention_dim: Optional[int] = None,
+    temporal_max_seq_length: int = 32,
+) -> Union["UpBlock3D", "CrossAttnUpBlock3D", "UpBlockMotion", "CrossAttnUpBlockMotion"]:
     if up_block_type == "UpBlock3D":
         return UpBlock3D(
             num_layers=num_layers,
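For illustration (not part of this commit): the `Union[...]` return annotation with quoted forward references is the standard way to type a string-dispatch factory like the two above; a minimal standalone sketch with hypothetical block classes:

from typing import Union

class DownBlock3D: ...
class CrossAttnDownBlock3D: ...

def get_block(block_type: str) -> Union["DownBlock3D", "CrossAttnDownBlock3D"]:
    if block_type == "DownBlock3D":
        return DownBlock3D()
    if block_type == "CrossAttnDownBlock3D":
        return CrossAttnDownBlock3D()
    raise ValueError(f"{block_type} does not exist.")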
@@ -236,12 +236,12 @@ class UNetMidBlock3DCrossAttn(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
num_attention_heads=1,
|
||||
output_scale_factor=1.0,
|
||||
cross_attention_dim=1280,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=True,
|
||||
upcast_attention=False,
|
||||
num_attention_heads: int = 1,
|
||||
output_scale_factor: float = 1.0,
|
||||
cross_attention_dim: int = 1280,
|
||||
dual_cross_attention: bool = False,
|
||||
use_linear_projection: bool = True,
|
||||
upcast_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -328,13 +328,13 @@ class UNetMidBlock3DCrossAttn(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
temb=None,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
num_frames=1,
|
||||
cross_attention_kwargs=None,
|
||||
):
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
num_frames: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.FloatTensor:
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
|
||||
for attn, temp_attn, resnet, temp_conv in zip(
|
||||
@@ -368,15 +368,15 @@ class CrossAttnDownBlock3D(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
num_attention_heads=1,
|
||||
cross_attention_dim=1280,
|
||||
output_scale_factor=1.0,
|
||||
downsample_padding=1,
|
||||
add_downsample=True,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
upcast_attention=False,
|
||||
num_attention_heads: int = 1,
|
||||
cross_attention_dim: int = 1280,
|
||||
output_scale_factor: float = 1.0,
|
||||
downsample_padding: int = 1,
|
||||
add_downsample: bool = True,
|
||||
dual_cross_attention: bool = False,
|
||||
use_linear_projection: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
@@ -454,13 +454,13 @@ class CrossAttnDownBlock3D(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
temb=None,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
num_frames=1,
|
||||
cross_attention_kwargs=None,
|
||||
):
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
num_frames: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
# TODO(Patrick, William) - attention mask is not used
output_states = ()

@@ -503,9 +503,9 @@ class DownBlock3D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_downsample=True,
downsample_padding=1,
output_scale_factor: float = 1.0,
add_downsample: bool = True,
downsample_padding: int = 1,
):
super().__init__()
resnets = []
@@ -552,7 +552,9 @@ class DownBlock3D(nn.Module):

self.gradient_checkpointing = False

def forward(self, hidden_states, temb=None, num_frames=1):
def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, num_frames: int = 1
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = ()

for resnet, temp_conv in zip(self.resnets, self.temp_convs):
@@ -584,15 +586,15 @@ class CrossAttnUpBlock3D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
add_upsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
resolution_idx=None,
num_attention_heads: int = 1,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
resolution_idx: Optional[int] = None,
):
super().__init__()
resnets = []
@@ -667,15 +669,15 @@ class CrossAttnUpBlock3D(nn.Module):

def forward(
self,
hidden_states,
res_hidden_states_tuple,
temb=None,
encoder_hidden_states=None,
upsample_size=None,
attention_mask=None,
num_frames=1,
cross_attention_kwargs=None,
):
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
num_frames: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.FloatTensor:
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
@@ -738,9 +740,9 @@ class UpBlock3D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_upsample=True,
resolution_idx=None,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
resolution_idx: Optional[int] = None,
):
super().__init__()
resnets = []
@@ -784,7 +786,14 @@ class UpBlock3D(nn.Module):
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx

def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
num_frames: int = 1,
) -> torch.FloatTensor:
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
@@ -833,12 +842,12 @@ class DownBlockMotion(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_downsample=True,
downsample_padding=1,
temporal_num_attention_heads=1,
temporal_cross_attention_dim=None,
temporal_max_seq_length=32,
output_scale_factor: float = 1.0,
add_downsample: bool = True,
downsample_padding: int = 1,
temporal_num_attention_heads: int = 1,
temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length: int = 32,
):
super().__init__()
resnets = []
@@ -890,7 +899,13 @@ class DownBlockMotion(nn.Module):

self.gradient_checkpointing = False

def forward(self, hidden_states, temb=None, scale: float = 1.0, num_frames=1):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
num_frames: int = 1,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = ()

blocks = zip(self.resnets, self.motion_modules)
@@ -944,19 +959,19 @@ class CrossAttnDownBlockMotion(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
downsample_padding=1,
add_downsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
attention_type="default",
temporal_cross_attention_dim=None,
temporal_num_attention_heads=8,
temporal_max_seq_length=32,
num_attention_heads: int = 1,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
downsample_padding: int = 1,
add_downsample: bool = True,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
attention_type: str = "default",
temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
):
super().__init__()
resnets = []
@@ -1043,14 +1058,14 @@ class CrossAttnDownBlockMotion(nn.Module):

def forward(
self,
hidden_states,
temb=None,
encoder_hidden_states=None,
attention_mask=None,
num_frames=1,
encoder_attention_mask=None,
cross_attention_kwargs=None,
additional_residuals=None,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
num_frames: int = 1,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
additional_residuals: Optional[torch.FloatTensor] = None,
):
output_states = ()

@@ -1121,7 +1136,7 @@ class CrossAttnUpBlockMotion(nn.Module):
out_channels: int,
prev_output_channel: int,
temb_channels: int,
resolution_idx: int = None,
resolution_idx: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
@@ -1130,18 +1145,18 @@ class CrossAttnUpBlockMotion(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
add_upsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
attention_type="default",
temporal_cross_attention_dim=None,
temporal_num_attention_heads=8,
temporal_max_seq_length=32,
num_attention_heads: int = 1,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
attention_type: str = "default",
temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
):
super().__init__()
resnets = []
@@ -1232,8 +1247,8 @@ class CrossAttnUpBlockMotion(nn.Module):
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
num_frames=1,
):
num_frames: int = 1,
) -> torch.FloatTensor:
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
is_freeu_enabled = (
getattr(self, "s1", None)
@@ -1317,7 +1332,7 @@ class UpBlockMotion(nn.Module):
prev_output_channel: int,
out_channels: int,
temb_channels: int,
resolution_idx: int = None,
resolution_idx: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -1325,12 +1340,12 @@ class UpBlockMotion(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_upsample=True,
temporal_norm_num_groups=32,
temporal_cross_attention_dim=None,
temporal_num_attention_heads=8,
temporal_max_seq_length=32,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
temporal_norm_num_groups: int = 32,
temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
):
super().__init__()
resnets = []
@@ -1381,8 +1396,14 @@ class UpBlockMotion(nn.Module):
self.resolution_idx = resolution_idx

def forward(
self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0, num_frames=1
):
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
scale: float = 1.0,
num_frames: int = 1,
) -> torch.FloatTensor:
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
@@ -1457,16 +1478,16 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
dual_cross_attention=False,
use_linear_projection=False,
upcast_attention=False,
attention_type="default",
temporal_num_attention_heads=1,
temporal_cross_attention_dim=None,
temporal_max_seq_length=32,
num_attention_heads: int = 1,
output_scale_factor: float = 1.0,
cross_attention_dim: int = 1280,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
upcast_attention: bool = False,
attention_type: str = "default",
temporal_num_attention_heads: int = 1,
temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length: int = 32,
):
super().__init__()

@@ -1560,7 +1581,7 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
num_frames=1,
num_frames: int = 1,
) -> torch.FloatTensor:
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)

@@ -98,14 +98,19 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
sample_size: Optional[int] = None,
in_channels: int = 4,
out_channels: int = 4,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D",
),
up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
up_block_types: Tuple[str, ...] = (
"UpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
@@ -302,7 +307,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
return processors

# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size):
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
r"""
Enable sliced attention computation.

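As a rough usage sketch (the `unet` instance and the slice choice are assumptions, not part of this diff):

# `unet` is assumed to be a UNet3DConditionModel instance.
# "auto" halves the number of attention heads per slice; an int fixes the slice count.
unet.set_attention_slice("auto")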
@@ -404,7 +409,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)

def enable_forward_chunking(self, chunk_size=None, dim=0):
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
Sets the attention processor to use [feed forward
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
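Feed-forward chunking trades compute time for memory by running the feed-forward layers over slices of the sequence. A hedged sketch of enabling it (the values are illustrative):

# Process the feed-forward pass in chunks of size 1 along dim 0.
unet.enable_forward_chunking(chunk_size=1, dim=0)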
@@ -460,7 +465,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)

self.set_attn_processor(processor, _remove_lora=True)

def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
module.gradient_checkpointing = value

@@ -510,7 +515,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet3DConditionOutput, Tuple]:
) -> Union[UNet3DConditionOutput, Tuple[torch.FloatTensor]]:
r"""
The [`UNet3DConditionModel`] forward method.

@@ -50,14 +50,14 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class MotionModules(nn.Module):
def __init__(
self,
in_channels,
layers_per_block=2,
num_attention_heads=8,
attention_bias=False,
cross_attention_dim=None,
activation_fn="geglu",
norm_num_groups=32,
max_seq_length=32,
in_channels: int,
layers_per_block: int = 2,
num_attention_heads: int = 8,
attention_bias: bool = False,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
norm_num_groups: int = 32,
max_seq_length: int = 32,
):
super().__init__()
self.motion_modules = nn.ModuleList([])
@@ -82,13 +82,13 @@ class MotionAdapter(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
block_out_channels=(320, 640, 1280, 1280),
motion_layers_per_block=2,
motion_mid_block_layers_per_block=1,
motion_num_attention_heads=8,
motion_norm_num_groups=32,
motion_max_seq_length=32,
use_motion_mid_block=True,
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
motion_layers_per_block: int = 2,
motion_mid_block_layers_per_block: int = 1,
motion_num_attention_heads: int = 8,
motion_norm_num_groups: int = 32,
motion_max_seq_length: int = 32,
use_motion_mid_block: bool = True,
):
"""Container to store AnimateDiff Motion Modules

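A minimal construction sketch using only the defaults shown in this signature (the variable name is an assumption):

adapter = MotionAdapter(
    block_out_channels=(320, 640, 1280, 1280),
    motion_layers_per_block=2,
    use_motion_mid_block=True,
)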
@@ -182,29 +182,29 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
sample_size: Optional[int] = None,
in_channels: int = 4,
out_channels: int = 4,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlockMotion",
"CrossAttnDownBlockMotion",
"CrossAttnDownBlockMotion",
"DownBlockMotion",
),
up_block_types: Tuple[str] = (
up_block_types: Tuple[str, ...] = (
"UpBlockMotion",
"CrossAttnUpBlockMotion",
"CrossAttnUpBlockMotion",
"CrossAttnUpBlockMotion",
),
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_num_groups: int = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
use_linear_projection: bool = False,
num_attention_heads: Optional[Union[int, Tuple[int]]] = 8,
motion_max_seq_length: Optional[int] = 32,
num_attention_heads: Union[int, Tuple[int, ...]] = 8,
motion_max_seq_length: int = 32,
motion_num_attention_heads: int = 8,
use_motion_mid_block: bool = True,
):
@@ -448,7 +448,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):

return model

def freeze_unet2d_params(self):
def freeze_unet2d_params(self) -> None:
"""Freeze the weights of just the UNet2DConditionModel, and leave the motion modules
unfrozen for fine tuning.
"""
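In a fine-tuning loop this would look roughly as follows; the optimizer choice and learning rate are assumptions, not prescribed by this diff:

import torch

model.freeze_unet2d_params()
# Only the motion-module parameters are left trainable.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)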
@@ -472,9 +472,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
for param in motion_modules.parameters():
param.requires_grad = True

return

def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]):
def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]) -> None:
for i, down_block in enumerate(motion_adapter.down_blocks):
self.down_blocks[i].motion_modules.load_state_dict(down_block.motion_modules.state_dict())
for i, up_block in enumerate(motion_adapter.up_blocks):
@@ -492,7 +490,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
variant: Optional[str] = None,
push_to_hub: bool = False,
**kwargs,
):
) -> None:
state_dict = self.state_dict()

# Extract all motion modules
@@ -582,7 +580,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
fn_recursive_attn_processor(name, module, processor)

# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size=None, dim=0):
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
Sets the attention processor to use [feed forward
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
@@ -612,7 +610,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
fn_recursive_feed_forward(module, chunk_size, dim)

# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
def disable_forward_chunking(self):
def disable_forward_chunking(self) -> None:
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
@@ -624,7 +622,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
fn_recursive_feed_forward(module, None, 0)

# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
def set_default_attn_processor(self) -> None:
"""
Disables custom attention processors and sets the default attention implementation.
"""
@@ -639,12 +637,12 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):

self.set_attn_processor(processor, _remove_lora=True)

def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
module.gradient_checkpointing = value

# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1, s2, b1, b2):
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.

The suffixes after the scaling factors represent the stage blocks where they are being applied.
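A hedged usage sketch; the scaling factors below are values commonly suggested for SD1.5-style backbones, not values mandated by this diff:

model.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
# ... run inference ...
model.disable_freeu()  # restores the default behavior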
@@ -669,7 +667,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
setattr(upsample_block, "b2", b2)

# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu
def disable_freeu(self):
def disable_freeu(self) -> None:
"""Disables the FreeU mechanism."""
freeu_keys = {"s1", "s2", "b1", "b2"}
for i, upsample_block in enumerate(self.up_blocks):
@@ -688,7 +686,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet3DConditionOutput, Tuple]:
) -> Union[UNet3DConditionOutput, Tuple[torch.Tensor]]:
r"""
The [`UNetMotionModel`] forward method.

@@ -148,7 +148,9 @@ class VQModel(ModelMixin, ConfigMixin):

return DecoderOutput(sample=dec)

def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
def forward(
self, sample: torch.FloatTensor, return_dict: bool = True
) -> Union[DecoderOutput, Tuple[torch.FloatTensor, ...]]:
r"""
The [`VQModel`] forward method.

@@ -37,7 +37,7 @@ class SchedulerType(Enum):
PIECEWISE_CONSTANT = "piecewise_constant"

def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1) -> LambdaLR:
"""
Create a schedule with a constant learning rate, using the learning rate set in optimizer.

@@ -53,7 +53,7 @@ def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)

def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1) -> LambdaLR:
"""
Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
increases linearly between 0 and the initial lr set in the optimizer.
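For example (the optimizer, model, and step counts are illustrative assumptions):

import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=500)
for step in range(1000):
    optimizer.step()
    scheduler.step()  # lr ramps from 0 to 1e-4 over 500 steps, then stays constant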
@@ -78,7 +78,7 @@ def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: in
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)

def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1):
def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1) -> LambdaLR:
"""
Create a schedule with a constant learning rate, using the learning rate set in optimizer.

@@ -120,7 +120,9 @@ def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_
return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)

def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
def get_linear_schedule_with_warmup(
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, last_epoch: int = -1
) -> LambdaLR:
"""
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
@@ -151,7 +153,7 @@ def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_st

def get_cosine_schedule_with_warmup(
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
):
) -> LambdaLR:
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
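The multiplier returned by the inner lambda can be sketched as follows; this is a reconstruction from the description above, not a verbatim copy of the elided body:

import math

def cosine_lr_lambda(step: int, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5) -> float:
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)  # linear warmup
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    # Half a cosine period by default (num_cycles=0.5): decays from 1.0 to 0.0.
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))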
@@ -185,7 +187,7 @@ def get_cosine_schedule_with_warmup(

def get_cosine_with_hard_restarts_schedule_with_warmup(
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
):
) -> LambdaLR:
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
@@ -219,8 +221,13 @@ def get_cosine_with_hard_restarts_schedule_with_warmup(

def get_polynomial_decay_schedule_with_warmup(
optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
):
optimizer: Optimizer,
num_warmup_steps: int,
num_training_steps: int,
lr_end: float = 1e-7,
power: float = 1.0,
last_epoch: int = -1,
) -> LambdaLR:
"""
Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
@@ -288,7 +295,7 @@ def get_scheduler(
num_cycles: int = 1,
power: float = 1.0,
last_epoch: int = -1,
):
) -> LambdaLR:
"""
Unified API to get any scheduler from its name.

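For example, the cosine variant above could be obtained by name, reusing the optimizer from the previous sketch (argument values are illustrative):

scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=10_000,
)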

@@ -28,7 +28,7 @@ from logging import (
WARN, # NOQA
WARNING, # NOQA
)
from typing import Optional
from typing import Dict, Optional

from tqdm import auto as tqdm_lib

@@ -49,7 +49,7 @@ _default_log_level = logging.WARNING
_tqdm_active = True

def _get_default_logging_level():
def _get_default_logging_level() -> int:
"""
If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
not - fall back to `_default_log_level`
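So the level can be chosen either through the environment or at runtime; a short sketch:

import os

os.environ["DIFFUSERS_VERBOSITY"] = "error"  # read once, so set it before diffusers is imported

# or, after import:
from diffusers.utils import logging
logging.set_verbosity_error()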
@@ -104,7 +104,7 @@ def _reset_library_root_logger() -> None:
_default_handler = None

def get_log_levels_dict():
def get_log_levels_dict() -> Dict[str, int]:
return log_levels

@@ -161,22 +161,22 @@ def set_verbosity(verbosity: int) -> None:
_get_library_root_logger().setLevel(verbosity)

def set_verbosity_info():
def set_verbosity_info() -> None:
"""Set the verbosity to the `INFO` level."""
return set_verbosity(INFO)

def set_verbosity_warning():
def set_verbosity_warning() -> None:
"""Set the verbosity to the `WARNING` level."""
return set_verbosity(WARNING)

def set_verbosity_debug():
def set_verbosity_debug() -> None:
"""Set the verbosity to the `DEBUG` level."""
return set_verbosity(DEBUG)

def set_verbosity_error():
def set_verbosity_error() -> None:
"""Set the verbosity to the `ERROR` level."""
return set_verbosity(ERROR)

@@ -263,7 +263,7 @@ def reset_format() -> None:
handler.setFormatter(None)

def warning_advice(self, *args, **kwargs):
def warning_advice(self, *args, **kwargs) -> None:
"""
This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
warning will not be printed
@@ -327,13 +327,13 @@ def is_progress_bar_enabled() -> bool:
return bool(_tqdm_active)

def enable_progress_bar():
def enable_progress_bar() -> None:
"""Enable tqdm progress bar."""
global _tqdm_active
_tqdm_active = True

def disable_progress_bar():
def disable_progress_bar() -> None:
"""Disable tqdm progress bar."""
global _tqdm_active
_tqdm_active = False

@@ -24,7 +24,7 @@ import numpy as np
from .import_utils import is_torch_available

def is_tensor(x):
def is_tensor(x) -> bool:
"""
Tests if `x` is a `torch.Tensor` or `np.ndarray`.
"""
@@ -66,7 +66,7 @@ class BaseOutput(OrderedDict):
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
)

def __post_init__(self):
def __post_init__(self) -> None:
class_fields = fields(self)

# Safety and consistency checks
@@ -97,14 +97,14 @@ class BaseOutput(OrderedDict):
def update(self, *args, **kwargs):
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

def __getitem__(self, k):
def __getitem__(self, k: Any) -> Any:
if isinstance(k, str):
inner_dict = dict(self.items())
return inner_dict[k]
else:
return self.to_tuple()[k]

def __setattr__(self, name, value):
def __setattr__(self, name: Any, value: Any) -> None:
if name in self.keys() and value is not None:
# Don't call self.__setitem__ to avoid recursion errors
super().__setitem__(name, value)
@@ -123,7 +123,7 @@ class BaseOutput(OrderedDict):
args = tuple(getattr(self, field.name) for field in fields(self))
return callable, args, *remaining

def to_tuple(self) -> Tuple[Any]:
def to_tuple(self) -> Tuple[Any, ...]:
"""
Convert self to a tuple containing all the attributes/keys that are not `None`.
"""

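Taken together, these accessors let a `BaseOutput` be read like a dict, a dataclass, or a tuple. A small sketch with a hypothetical subclass:

from dataclasses import dataclass

import torch
from diffusers.utils import BaseOutput

@dataclass
class ExampleOutput(BaseOutput):  # hypothetical subclass for illustration
    sample: torch.Tensor

out = ExampleOutput(sample=torch.zeros(1))
assert out["sample"] is out.sample      # string keys read the inner dict
assert out.to_tuple()[0] is out.sample  # integer indexing goes through to_tuple()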
@@ -82,14 +82,14 @@ def randn_tensor(
return latents

def is_compiled_module(module):
def is_compiled_module(module) -> bool:
"""Check whether the module was compiled with torch.compile()"""
if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
return False
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)

def fourier_filter(x_in, threshold, scale):
def fourier_filter(x_in: torch.Tensor, threshold: int, scale: int) -> torch.Tensor:
"""Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).

This version of the method comes from here:
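The body elided by this hunk essentially scales the centered low-frequency box of the spectrum; a reconstruction under that reading, not a verbatim copy of the referenced implementation:

import torch
import torch.fft as fft

def fourier_filter_sketch(x_in: torch.Tensor, threshold: int, scale: int) -> torch.Tensor:
    # FFT over the spatial dims, then move the DC component to the center.
    x_freq = fft.fftshift(fft.fftn(x_in, dim=(-2, -1)), dim=(-2, -1))
    B, C, H, W = x_freq.shape
    mask = torch.ones_like(x_freq.real)
    crow, ccol = H // 2, W // 2
    # Scale the low-frequency box of half-width `threshold` around the center.
    mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale
    x_freq = x_freq * mask
    # Undo the shift, transform back, keep the real part.
    return fft.ifftn(fft.ifftshift(x_freq, dim=(-2, -1)), dim=(-2, -1)).real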