mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Make style
This commit is contained in:
11
Makefile
11
Makefile
@@ -34,13 +34,9 @@ autogenerate_code: deps_table_update
|
||||
# Check that the repo is in a good state
|
||||
|
||||
repo-consistency:
|
||||
python utils/check_copies.py
|
||||
python utils/check_table.py
|
||||
python utils/check_dummies.py
|
||||
python utils/check_repo.py
|
||||
python utils/check_inits.py
|
||||
python utils/check_config_docstrings.py
|
||||
python utils/tests_fetcher.py --sanity_check
|
||||
|
||||
# this target runs checks on all files
|
||||
|
||||
@@ -48,14 +44,13 @@ quality:
|
||||
black --check --preview $(check_dirs)
|
||||
isort --check-only $(check_dirs)
|
||||
flake8 $(check_dirs)
|
||||
doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source
|
||||
doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source
|
||||
|
||||
# Format source code automatically and check is there are any problems left that need manual fixing
|
||||
|
||||
extra_style_checks:
|
||||
python utils/custom_init_isort.py
|
||||
python utils/sort_auto_mappings.py
|
||||
doc-builder style src/transformers docs/source --max_len 119 --path_to_docs docs/source
|
||||
doc-builder style src/diffusers docs/source --max_len 119 --path_to_docs docs/source
|
||||
|
||||
# this target runs checks on all files and potentially modifies some of them
|
||||
|
||||
@@ -73,8 +68,6 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency
|
||||
|
||||
fix-copies:
|
||||
python utils/check_dummies.py --fix_and_overwrite
|
||||
python utils/check_table.py --fix_and_overwrite
|
||||
python utils/check_copies.py --fix_and_overwrite
|
||||
|
||||
# Run tests for the library
|
||||
|
||||
|
||||
@@ -47,12 +47,11 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
|
||||
|
||||
def init_git_repo(args, at_init: bool = False):
|
||||
"""
|
||||
Initializes a git repo in `args.hub_model_id`.
|
||||
Args:
|
||||
Initializes a git repo in `args.hub_model_id`.
|
||||
at_init (`bool`, *optional*, defaults to `False`):
|
||||
Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is
|
||||
`True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped
|
||||
out.
|
||||
Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True`
|
||||
and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out.
|
||||
"""
|
||||
if args.local_rank not in [-1, 0]:
|
||||
return
|
||||
@@ -102,8 +101,8 @@ def push_to_hub(
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
|
||||
Parameters:
|
||||
Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
|
||||
commit_message (`str`, *optional*, defaults to `"End of training"`):
|
||||
Message to commit while pushing.
|
||||
blocking (`bool`, *optional*, defaults to `True`):
|
||||
@@ -111,8 +110,8 @@ def push_to_hub(
|
||||
kwargs:
|
||||
Additional keyword arguments passed along to [`create_model_card`].
|
||||
Returns:
|
||||
The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of
|
||||
the commit and an object to track the progress of the commit if `blocking=True`
|
||||
The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the
|
||||
commit and an object to track the progress of the commit if `blocking=True`
|
||||
"""
|
||||
|
||||
if args.hub_model_id is None:
|
||||
|
||||
@@ -123,16 +123,16 @@ class ModelMixin(torch.nn.Module):
|
||||
r"""
|
||||
Base class for all models.
|
||||
|
||||
[`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading,
|
||||
downloading and saving models as well as a few methods common to all models to:
|
||||
[`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
|
||||
and saving models as well as a few methods common to all models to:
|
||||
|
||||
- resize the input embeddings,
|
||||
- prune heads in the self-attention heads.
|
||||
|
||||
Class attributes (overridden by derived classes):
|
||||
|
||||
- **config_class** ([`ConfigMixin`]) -- A subclass of [`ConfigMixin`] to use as configuration class
|
||||
for this model architecture.
|
||||
- **config_class** ([`ConfigMixin`]) -- A subclass of [`ConfigMixin`] to use as configuration class for this
|
||||
model architecture.
|
||||
- **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
|
||||
taking as arguments:
|
||||
|
||||
@@ -227,8 +227,8 @@ class ModelMixin(torch.nn.Module):
|
||||
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
|
||||
user or organization name, like `dbmdz/bert-base-german-cased`.
|
||||
- A path to a *directory* containing model weights saved using
|
||||
[`~ModelMixin.save_pretrained`], e.g., `./my_model_directory/`.
|
||||
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`],
|
||||
e.g., `./my_model_directory/`.
|
||||
|
||||
config (`Union[ConfigMixin, str, os.PathLike]`, *optional*):
|
||||
Can be either:
|
||||
@@ -236,13 +236,13 @@ class ModelMixin(torch.nn.Module):
|
||||
- an instance of a class derived from [`ConfigMixin`],
|
||||
- a string or path valid as input to [`~ConfigMixin.from_pretrained`].
|
||||
|
||||
ConfigMixinuration for the model to use instead of an automatically loaded configuration. ConfigMixinuration can
|
||||
be automatically loaded when:
|
||||
ConfigMixinuration for the model to use instead of an automatically loaded configuration.
|
||||
ConfigMixinuration can be automatically loaded when:
|
||||
|
||||
- The model is a model provided by the library (loaded with the *model id* string of a pretrained
|
||||
model).
|
||||
- The model was saved using [`~ModelMixin.save_pretrained`] and is reloaded by supplying the
|
||||
save directory.
|
||||
- The model was saved using [`~ModelMixin.save_pretrained`] and is reloaded by supplying the save
|
||||
directory.
|
||||
- The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
|
||||
configuration JSON file named *config.json* is found in the directory.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
@@ -292,10 +292,10 @@ class ModelMixin(torch.nn.Module):
|
||||
underlying model's `__init__` method (we assume all relevant updates to the configuration have
|
||||
already been done)
|
||||
- If a configuration is not provided, `kwargs` will be first passed to the configuration class
|
||||
initialization function ([`~ConfigMixin.from_pretrained`]). Each key of `kwargs` that
|
||||
corresponds to a configuration attribute will be used to override said attribute with the
|
||||
supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
|
||||
will be passed to the underlying model's `__init__` function.
|
||||
initialization function ([`~ConfigMixin.from_pretrained`]). Each key of `kwargs` that corresponds
|
||||
to a configuration attribute will be used to override said attribute with the supplied `kwargs`
|
||||
value. Remaining keys that do not correspond to any configuration attribute will be passed to the
|
||||
underlying model's `__init__` function.
|
||||
|
||||
<Tip>
|
||||
|
||||
|
||||
@@ -22,14 +22,12 @@ def get_timestep_embedding(
|
||||
timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000
|
||||
):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||
Create sinusoidal timestep embeddings.
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||
|
||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param embedding_dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an [N x dim] Tensor of positional embeddings.
|
||||
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
||||
embeddings. :return: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||
|
||||
|
||||
@@ -58,9 +58,8 @@ class Upsample(nn.Module):
|
||||
"""
|
||||
An upsampling layer with an optional convolution.
|
||||
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
|
||||
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
@@ -97,9 +96,8 @@ class Downsample(nn.Module):
|
||||
"""
|
||||
A downsampling layer with an optional convolution.
|
||||
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
|
||||
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
downsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
@@ -143,9 +141,8 @@ class GlideUpsample(nn.Module):
|
||||
"""
|
||||
An upsampling layer with an optional convolution.
|
||||
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
|
||||
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
@@ -171,10 +168,9 @@ class GlideUpsample(nn.Module):
|
||||
|
||||
class LDMUpsample(nn.Module):
|
||||
"""
|
||||
An upsampling layer with an optional convolution.
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param
|
||||
use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D.
|
||||
If 3D, then
|
||||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
|
||||
@@ -82,8 +82,7 @@ def normalization(channels, swish=0.0):
|
||||
"""
|
||||
Make a standard normalization layer, with an optional swish activation.
|
||||
|
||||
:param channels: number of input channels.
|
||||
:return: an nn.Module for normalization.
|
||||
:param channels: number of input channels. :return: an nn.Module for normalization.
|
||||
"""
|
||||
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
|
||||
|
||||
@@ -111,8 +110,7 @@ class TimestepBlock(nn.Module):
|
||||
|
||||
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
"""
|
||||
A sequential module that passes timestep embeddings to the children that
|
||||
support it as an extra input.
|
||||
A sequential module that passes timestep embeddings to the children that support it as an extra input.
|
||||
"""
|
||||
|
||||
def forward(self, x, emb, encoder_out=None):
|
||||
@@ -130,9 +128,8 @@ class Downsample(nn.Module):
|
||||
"""
|
||||
A downsampling layer with an optional convolution.
|
||||
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
|
||||
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
downsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
@@ -158,17 +155,13 @@ class ResBlock(TimestepBlock):
|
||||
"""
|
||||
A residual block that can optionally change the number of channels.
|
||||
|
||||
:param channels: the number of input channels.
|
||||
:param emb_channels: the number of timestep embedding channels.
|
||||
:param dropout: the rate of dropout.
|
||||
:param out_channels: if specified, the number of out channels.
|
||||
:param use_conv: if True and out_channels is specified, use a spatial
|
||||
convolution instead of a smaller 1x1 convolution to change the
|
||||
channels in the skip connection.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D.
|
||||
:param use_checkpoint: if True, use gradient checkpointing on this module.
|
||||
:param up: if True, use this block for upsampling.
|
||||
:param down: if True, use this block for downsampling.
|
||||
:param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels.
|
||||
:param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param
|
||||
use_conv: if True and out_channels is specified, use a spatial
|
||||
convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
|
||||
on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
|
||||
downsampling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -235,8 +228,7 @@ class ResBlock(TimestepBlock):
|
||||
"""
|
||||
Apply the block to a Tensor, conditioned on a timestep embedding.
|
||||
|
||||
:param x: an [N x C x ...] Tensor of features.
|
||||
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
||||
:param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
if self.updown:
|
||||
@@ -320,8 +312,8 @@ class QKVAttention(nn.Module):
|
||||
"""
|
||||
Apply QKV attention.
|
||||
|
||||
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
||||
:return: an [N x (H * C) x T] tensor after attention.
|
||||
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x T] tensor after
|
||||
attention.
|
||||
"""
|
||||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
@@ -343,29 +335,24 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
The full UNet model with attention and timestep embedding.
|
||||
|
||||
:param in_channels: channels in the input Tensor.
|
||||
:param model_channels: base channel count for the model.
|
||||
:param out_channels: channels in the output Tensor.
|
||||
:param num_res_blocks: number of residual blocks per downsample.
|
||||
:param in_channels: channels in the input Tensor. :param model_channels: base channel count for the model. :param
|
||||
out_channels: channels in the output Tensor. :param num_res_blocks: number of residual blocks per downsample.
|
||||
:param attention_resolutions: a collection of downsample rates at which
|
||||
attention will take place. May be a set, list, or tuple.
|
||||
For example, if this contains 4, then at 4x downsampling, attention
|
||||
will be used.
|
||||
:param dropout: the dropout probability.
|
||||
:param channel_mult: channel multiplier for each level of the UNet.
|
||||
:param conv_resample: if True, use learned convolutions for upsampling and
|
||||
attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x
|
||||
downsampling, attention will be used.
|
||||
:param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param
|
||||
conv_resample: if True, use learned convolutions for upsampling and
|
||||
downsampling.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D.
|
||||
:param num_classes: if specified (as an int), then this model will be
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this
|
||||
model will be
|
||||
class-conditional with `num_classes` classes.
|
||||
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
|
||||
:param num_heads: the number of attention heads in each attention layer.
|
||||
:param num_heads_channels: if specified, ignore num_heads and instead use
|
||||
:param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention
|
||||
heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use
|
||||
a fixed channel width per attention head.
|
||||
:param num_heads_upsample: works with num_heads to set a different number
|
||||
of heads for upsampling. Deprecated.
|
||||
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
|
||||
:param resblock_updown: use residual blocks for up/downsampling.
|
||||
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks
|
||||
for up/downsampling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -571,10 +558,8 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
:param timesteps: a 1-D batch of timesteps.
|
||||
:param y: an [N] Tensor of labels, if class-conditional.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
:param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch of timesteps. :param y: an [N]
|
||||
Tensor of labels, if class-conditional. :return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
|
||||
hs = []
|
||||
|
||||
@@ -222,11 +222,8 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
class SpatialTransformer(nn.Module):
|
||||
"""
|
||||
Transformer block for image-like data.
|
||||
First, project the input (aka embedding)
|
||||
and reshape to b, t, d.
|
||||
Then apply standard transformer action.
|
||||
Finally, reshape to image
|
||||
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
|
||||
standard transformer action. Finally, reshape to image
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
|
||||
@@ -331,8 +328,7 @@ def normalization(channels, swish=0.0):
|
||||
"""
|
||||
Make a standard normalization layer, with an optional swish activation.
|
||||
|
||||
:param channels: number of input channels.
|
||||
:return: an nn.Module for normalization.
|
||||
:param channels: number of input channels. :return: an nn.Module for normalization.
|
||||
"""
|
||||
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
|
||||
|
||||
@@ -382,8 +378,7 @@ class TimestepBlock(nn.Module):
|
||||
|
||||
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
"""
|
||||
A sequential module that passes timestep embeddings to the children that
|
||||
support it as an extra input.
|
||||
A sequential module that passes timestep embeddings to the children that support it as an extra input.
|
||||
"""
|
||||
|
||||
def forward(self, x, emb, context=None):
|
||||
@@ -399,10 +394,9 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
|
||||
class Downsample(nn.Module):
|
||||
"""
|
||||
A downsampling layer with an optional convolution.
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
A downsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param
|
||||
use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D.
|
||||
If 3D, then
|
||||
downsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
@@ -426,18 +420,14 @@ class Downsample(nn.Module):
|
||||
|
||||
class ResBlock(TimestepBlock):
|
||||
"""
|
||||
A residual block that can optionally change the number of channels.
|
||||
:param channels: the number of input channels.
|
||||
:param emb_channels: the number of timestep embedding channels.
|
||||
:param dropout: the rate of dropout.
|
||||
:param out_channels: if specified, the number of out channels.
|
||||
:param use_conv: if True and out_channels is specified, use a spatial
|
||||
convolution instead of a smaller 1x1 convolution to change the
|
||||
channels in the skip connection.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D.
|
||||
:param use_checkpoint: if True, use gradient checkpointing on this module.
|
||||
:param up: if True, use this block for upsampling.
|
||||
:param down: if True, use this block for downsampling.
|
||||
A residual block that can optionally change the number of channels. :param channels: the number of input channels.
|
||||
:param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. :param
|
||||
out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use
|
||||
a spatial
|
||||
convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
|
||||
on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
|
||||
downsampling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -525,8 +515,8 @@ class ResBlock(TimestepBlock):
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
An attention block that allows spatial positions to attend to each other.
|
||||
Originally ported from here, but adapted to the N-d case.
|
||||
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
|
||||
to the N-d case.
|
||||
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
||||
"""
|
||||
|
||||
@@ -575,9 +565,8 @@ class QKVAttention(nn.Module):
|
||||
|
||||
def forward(self, qkv):
|
||||
"""
|
||||
Apply QKV attention.
|
||||
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
|
||||
:return: an [N x (H * C) x T] tensor after attention.
|
||||
Apply QKV attention. :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
|
||||
T] tensor after attention.
|
||||
"""
|
||||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
@@ -600,13 +589,9 @@ class QKVAttention(nn.Module):
|
||||
|
||||
def count_flops_attn(model, _x, y):
|
||||
"""
|
||||
A counter for the `thop` package to count the operations in an
|
||||
attention operation.
|
||||
Meant to be used like:
|
||||
A counter for the `thop` package to count the operations in an attention operation. Meant to be used like:
|
||||
macs, params = thop.profile(
|
||||
model,
|
||||
inputs=(inputs, timestamps),
|
||||
custom_ops={QKVAttention: QKVAttention.count_flops},
|
||||
model, inputs=(inputs, timestamps), custom_ops={QKVAttention: QKVAttention.count_flops},
|
||||
)
|
||||
"""
|
||||
b, c, *spatial = y[0].shape
|
||||
@@ -629,9 +614,8 @@ class QKVAttentionLegacy(nn.Module):
|
||||
|
||||
def forward(self, qkv):
|
||||
"""
|
||||
Apply QKV attention.
|
||||
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
||||
:return: an [N x (H * C) x T] tensor after attention.
|
||||
Apply QKV attention. :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
|
||||
T] tensor after attention.
|
||||
"""
|
||||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
@@ -650,31 +634,25 @@ class QKVAttentionLegacy(nn.Module):
|
||||
|
||||
class UNetLDMModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
The full UNet model with attention and timestep embedding.
|
||||
:param in_channels: channels in the input Tensor.
|
||||
:param model_channels: base channel count for the model.
|
||||
:param out_channels: channels in the output Tensor.
|
||||
:param num_res_blocks: number of residual blocks per downsample.
|
||||
:param attention_resolutions: a collection of downsample rates at which
|
||||
attention will take place. May be a set, list, or tuple.
|
||||
For example, if this contains 4, then at 4x downsampling, attention
|
||||
will be used.
|
||||
:param dropout: the dropout probability.
|
||||
:param channel_mult: channel multiplier for each level of the UNet.
|
||||
:param conv_resample: if True, use learned convolutions for upsampling and
|
||||
The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
|
||||
model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param
|
||||
num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample
|
||||
rates at which
|
||||
attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x
|
||||
downsampling, attention will be used.
|
||||
:param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param
|
||||
conv_resample: if True, use learned convolutions for upsampling and
|
||||
downsampling.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D.
|
||||
:param num_classes: if specified (as an int), then this model will be
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this
|
||||
model will be
|
||||
class-conditional with `num_classes` classes.
|
||||
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
|
||||
:param num_heads: the number of attention heads in each attention layer.
|
||||
:param num_heads_channels: if specified, ignore num_heads and instead use
|
||||
:param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention
|
||||
heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use
|
||||
a fixed channel width per attention head.
|
||||
:param num_heads_upsample: works with num_heads to set a different number
|
||||
of heads for upsampling. Deprecated.
|
||||
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
|
||||
:param resblock_updown: use residual blocks for up/downsampling.
|
||||
:param use_new_attention_order: use a different attention pattern for potentially
|
||||
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks
|
||||
for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially
|
||||
increased efficiency.
|
||||
"""
|
||||
|
||||
@@ -975,12 +953,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
:param timesteps: a 1-D batch of timesteps.
|
||||
:param context: conditioning plugged in via crossattn
|
||||
:param y: an [N] Tensor of labels, if class-conditional.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch
|
||||
of timesteps. :param context: conditioning plugged in via crossattn :param y: an [N] Tensor of labels, if
|
||||
class-conditional. :return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
@@ -1012,8 +987,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
||||
|
||||
class EncoderUNetModel(nn.Module):
|
||||
"""
|
||||
The half UNet model with attention and timestep embedding.
|
||||
For usage, see UNet.
|
||||
The half UNet model with attention and timestep embedding. For usage, see UNet.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -1197,10 +1171,8 @@ class EncoderUNetModel(nn.Module):
|
||||
|
||||
def forward(self, x, timesteps):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
:param timesteps: a 1-D batch of timesteps.
|
||||
:return: an [N x K] Tensor of outputs.
|
||||
Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch
|
||||
of timesteps. :return: an [N x K] Tensor of outputs.
|
||||
"""
|
||||
emb = self.time_embed(
|
||||
get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
|
||||
@@ -111,10 +111,8 @@ class ResidualTemporalBlock(nn.Module):
|
||||
|
||||
def forward(self, x, t):
|
||||
"""
|
||||
x : [ batch_size x inp_channels x horizon ]
|
||||
t : [ batch_size x embed_dim ]
|
||||
returns:
|
||||
out : [ batch_size x out_channels x horizon ]
|
||||
x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
|
||||
out_channels x horizon ]
|
||||
"""
|
||||
out = self.blocks[0](x) + self.time_mlp(t)
|
||||
out = self.blocks[1](out)
|
||||
|
||||
@@ -136,26 +136,21 @@ def naive_downsample_2d(x, factor=2):
|
||||
def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
|
||||
"""Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
|
||||
|
||||
Padding is performed only once at the beginning, not between the
|
||||
operations.
|
||||
The fused op is considerably more efficient than performing the same
|
||||
calculation
|
||||
using standard TensorFlow ops. It supports gradients of arbitrary order.
|
||||
Args:
|
||||
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
|
||||
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
||||
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
|
||||
order.
|
||||
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
|
||||
C]`.
|
||||
w: Weight tensor of the shape `[filterH, filterW, inChannels,
|
||||
outChannels]`. Grouped convolution can be performed by `inChannels =
|
||||
x.shape[0] // numGroups`.
|
||||
k: FIR filter of the shape `[firH, firW]` or `[firN]`
|
||||
(separable). The default is `[1] * factor`, which corresponds to
|
||||
nearest-neighbor upsampling.
|
||||
factor: Integer upsampling factor (default: 2).
|
||||
gain: Scaling factor for signal magnitude (default: 1.0).
|
||||
w: Weight tensor of the shape `[filterH, filterW, inChannels,
|
||||
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
|
||||
k: FIR filter of the shape `[firH, firW]` or `[firN]`
|
||||
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
|
||||
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[N, C, H * factor, W * factor]` or
|
||||
`[N, H * factor, W * factor, C]`, and same datatype as `x`.
|
||||
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
|
||||
`x`.
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
@@ -208,25 +203,21 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
|
||||
def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
|
||||
"""Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
|
||||
|
||||
Padding is performed only once at the beginning, not between the operations.
|
||||
The fused op is considerably more efficient than performing the same
|
||||
calculation
|
||||
using standard TensorFlow ops. It supports gradients of arbitrary order.
|
||||
Args:
|
||||
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
|
||||
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
||||
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
|
||||
order.
|
||||
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
|
||||
C]`.
|
||||
w: Weight tensor of the shape `[filterH, filterW, inChannels,
|
||||
outChannels]`. Grouped convolution can be performed by `inChannels =
|
||||
x.shape[0] // numGroups`.
|
||||
k: FIR filter of the shape `[firH, firW]` or `[firN]`
|
||||
(separable). The default is `[1] * factor`, which corresponds to
|
||||
average pooling.
|
||||
factor: Integer downsampling factor (default: 2).
|
||||
gain: Scaling factor for signal magnitude (default: 1.0).
|
||||
w: Weight tensor of the shape `[filterH, filterW, inChannels,
|
||||
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
|
||||
k: FIR filter of the shape `[firH, firW]` or `[firN]`
|
||||
(separable). The default is `[1] * factor`, which corresponds to average pooling.
|
||||
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[N, C, H // factor, W // factor]` or
|
||||
`[N, H // factor, W // factor, C]`, and same datatype as `x`.
|
||||
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same datatype
|
||||
as `x`.
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
@@ -258,22 +249,16 @@ def _shape(x, dim):
|
||||
def upsample_2d(x, k=None, factor=2, gain=1):
|
||||
r"""Upsample a batch of 2D images with the given filter.
|
||||
|
||||
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
|
||||
and upsamples each image with the given filter. The filter is normalized so
|
||||
that
|
||||
if the input pixels are constant, they will be scaled by the specified
|
||||
`gain`.
|
||||
Pixels outside the image are assumed to be zero, and the filter is padded
|
||||
with
|
||||
zeros so that its shape is a multiple of the upsampling factor.
|
||||
Args:
|
||||
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
|
||||
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
|
||||
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
|
||||
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
|
||||
multiple of the upsampling factor.
|
||||
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
|
||||
C]`.
|
||||
k: FIR filter of the shape `[firH, firW]` or `[firN]`
|
||||
(separable). The default is `[1] * factor`, which corresponds to
|
||||
nearest-neighbor upsampling.
|
||||
factor: Integer upsampling factor (default: 2).
|
||||
gain: Scaling factor for signal magnitude (default: 1.0).
|
||||
k: FIR filter of the shape `[firH, firW]` or `[firN]`
|
||||
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
|
||||
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[N, C, H * factor, W * factor]`
|
||||
@@ -289,22 +274,16 @@ def upsample_2d(x, k=None, factor=2, gain=1):
|
||||
def downsample_2d(x, k=None, factor=2, gain=1):
|
||||
r"""Downsample a batch of 2D images with the given filter.
|
||||
|
||||
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
|
||||
and downsamples each image with the given filter. The filter is normalized
|
||||
so that
|
||||
if the input pixels are constant, they will be scaled by the specified
|
||||
`gain`.
|
||||
Pixels outside the image are assumed to be zero, and the filter is padded
|
||||
with
|
||||
zeros so that its shape is a multiple of the downsampling factor.
|
||||
Args:
|
||||
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
|
||||
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
|
||||
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
|
||||
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
|
||||
shape is a multiple of the downsampling factor.
|
||||
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
|
||||
C]`.
|
||||
k: FIR filter of the shape `[firH, firW]` or `[firN]`
|
||||
(separable). The default is `[1] * factor`, which corresponds to
|
||||
average pooling.
|
||||
factor: Integer downsampling factor (default: 2).
|
||||
gain: Scaling factor for signal magnitude (default: 1.0).
|
||||
k: FIR filter of the shape `[firH, firW]` or `[firN]`
|
||||
(separable). The default is `[1] * factor`, which corresponds to average pooling.
|
||||
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[N, C, H // factor, W // factor]`
|
||||
|
||||
@@ -290,7 +290,7 @@ def normalize_numbers(text):
|
||||
return text
|
||||
|
||||
|
||||
""" from https://github.com/keithito/tacotron """
|
||||
""" from https://github.com/keithito/tacotron"""
|
||||
|
||||
|
||||
_pad = "_"
|
||||
@@ -322,8 +322,8 @@ def get_arpabet(word, dictionary):
|
||||
def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):
|
||||
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
||||
|
||||
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
|
||||
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
|
||||
The text can optionally have ARPAbet sequences enclosed in curly braces embedded in it. For example, "Turn left on
|
||||
{HH AW1 S S T AH0 N} Street."
|
||||
|
||||
Args:
|
||||
text: string to convert to a sequence
|
||||
|
||||
@@ -29,8 +29,7 @@ from ..pipeline_utils import DiffusionPipeline
|
||||
def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
|
||||
"""
|
||||
Embed a diffusion step $t$ into a higher dimensional space
|
||||
E.g. the embedding vector in the 128-dimensional space is
|
||||
[sin(t * 10^(0*4/63)), ... , sin(t * 10^(63*4/63)),
|
||||
E.g. the embedding vector in the 128-dimensional space is [sin(t * 10^(0*4/63)), ... , sin(t * 10^(63*4/63)),
|
||||
cos(t * 10^(0*4/63)), ... , cos(t * 10^(63*4/63))]
|
||||
|
||||
Parameters:
|
||||
@@ -53,8 +52,7 @@ def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
|
||||
|
||||
|
||||
"""
|
||||
Below scripts were borrowed from
|
||||
https://github.com/philsyn/DiffWave-Vocoder/blob/master/WaveNet.py
|
||||
Below scripts were borrowed from https://github.com/philsyn/DiffWave-Vocoder/blob/master/WaveNet.py
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@@ -699,9 +699,8 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
||||
"""
|
||||
Extract values from a 1-D numpy array for a batch of indices.
|
||||
|
||||
:param arr: the 1-D numpy array.
|
||||
:param timesteps: a tensor of indices into the array to extract.
|
||||
:param broadcast_shape: a larger shape of K dimensions with the batch
|
||||
:param arr: the 1-D numpy array. :param timesteps: a tensor of indices into the array to extract. :param
|
||||
broadcast_shape: a larger shape of K dimensions with the batch
|
||||
dimension equal to the length of timesteps.
|
||||
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
""" from https://github.com/jaywalnut310/glow-tts """
|
||||
""" from https://github.com/jaywalnut310/glow-tts"""
|
||||
|
||||
import math
|
||||
|
||||
|
||||
@@ -554,11 +554,9 @@ class LDMBertModel(LDMBertPreTrainedModel):
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||
From Fairseq.
|
||||
Build sinusoidal embeddings.
|
||||
This matches the implementation in tensor2tensor, but differs slightly
|
||||
from the description in Section 3.5 of "Attention Is All You Need".
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal
|
||||
embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section
|
||||
3.5 of "Attention Is All You Need".
|
||||
"""
|
||||
assert len(timesteps.shape) == 1
|
||||
|
||||
@@ -1055,8 +1053,8 @@ class Decoder(nn.Module):
|
||||
|
||||
class VectorQuantizer(nn.Module):
|
||||
"""
|
||||
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
|
||||
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
|
||||
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
|
||||
multiplications and allows for post-hoc remapping of indices.
|
||||
"""
|
||||
|
||||
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
||||
|
||||
@@ -25,13 +25,12 @@ from .scheduling_utils import SchedulerMixin
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function,
|
||||
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
|
||||
:param num_diffusion_timesteps: the number of betas to produce.
|
||||
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
||||
produces the cumulative product of (1-beta) up to that
|
||||
part of the diffusion process.
|
||||
:param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t
|
||||
from 0 to 1 and
|
||||
produces the cumulative product of (1-beta) up to that part of the diffusion process.
|
||||
:param max_beta: the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
"""
|
||||
|
||||
@@ -25,13 +25,12 @@ from .scheduling_utils import SchedulerMixin
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function,
|
||||
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
|
||||
:param num_diffusion_timesteps: the number of betas to produce.
|
||||
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
||||
produces the cumulative product of (1-beta) up to that
|
||||
part of the diffusion process.
|
||||
:param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t
|
||||
from 0 to 1 and
|
||||
produces the cumulative product of (1-beta) up to that part of the diffusion process.
|
||||
:param max_beta: the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
"""
|
||||
|
||||
@@ -24,13 +24,12 @@ from .scheduling_utils import SchedulerMixin
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function,
|
||||
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
|
||||
:param num_diffusion_timesteps: the number of betas to produce.
|
||||
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
||||
produces the cumulative product of (1-beta) up to that
|
||||
part of the diffusion process.
|
||||
:param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t
|
||||
from 0 to 1 and
|
||||
produces the cumulative product of (1-beta) up to that part of the diffusion process.
|
||||
:param max_beta: the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
"""
|
||||
|
||||
@@ -20,11 +20,10 @@ class EMAModel:
|
||||
):
|
||||
"""
|
||||
@crowsonkb's notes on EMA Warmup:
|
||||
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
|
||||
good values for models you plan to train for a million or more steps (reaches decay
|
||||
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
|
||||
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
|
||||
215.4k steps).
|
||||
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
|
||||
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
|
||||
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
|
||||
at 215.4k steps).
|
||||
Args:
|
||||
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
|
||||
power (float): Exponential factor of EMA warmup. Default: 2/3.
|
||||
|
||||
@@ -89,20 +89,20 @@ class RevisionNotFoundError(HTTPError):
|
||||
|
||||
|
||||
TRANSFORMERS_IMPORT_ERROR = """
|
||||
{0} requires the transformers library but it was not found in your environment. You can install it with pip:
|
||||
`pip install transformers`
|
||||
{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip
|
||||
install transformers`
|
||||
"""
|
||||
|
||||
|
||||
UNIDECODE_IMPORT_ERROR = """
|
||||
{0} requires the unidecode library but it was not found in your environment. You can install it with pip:
|
||||
`pip install Unidecode`
|
||||
{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install
|
||||
Unidecode`
|
||||
"""
|
||||
|
||||
|
||||
INFLECT_IMPORT_ERROR = """
|
||||
{0} requires the inflect library but it was not found in your environment. You can install it with pip:
|
||||
`pip install inflect`
|
||||
{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install
|
||||
inflect`
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class GradTTS(metaclass=DummyObject):
|
||||
class GradTTSPipeline(metaclass=DummyObject):
|
||||
_backends = ["transformers", "inflect", "unidecode"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
||||
@@ -31,14 +31,14 @@ class UNetGradTTSModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["transformers"])
|
||||
|
||||
|
||||
class Glide(metaclass=DummyObject):
|
||||
class GlidePipeline(metaclass=DummyObject):
|
||||
_backends = ["transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["transformers"])
|
||||
|
||||
|
||||
class LatentDiffusion(metaclass=DummyObject):
|
||||
class LatentDiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
||||
@@ -233,8 +233,8 @@ def disable_propagation() -> None:
|
||||
|
||||
def enable_propagation() -> None:
|
||||
"""
|
||||
Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to
|
||||
prevent double logging if the root logger has been configured.
|
||||
Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent
|
||||
double logging if the root logger has been configured.
|
||||
"""
|
||||
|
||||
_configure_library_root_logger()
|
||||
|
||||
@@ -22,7 +22,6 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
GradTTSPipeline,
|
||||
BDDMPipeline,
|
||||
DDIMPipeline,
|
||||
DDIMScheduler,
|
||||
@@ -31,6 +30,7 @@ from diffusers import (
|
||||
GlidePipeline,
|
||||
GlideSuperResUNetModel,
|
||||
GlideTextToImageUNetModel,
|
||||
GradTTSPipeline,
|
||||
GradTTSScheduler,
|
||||
LatentDiffusionPipeline,
|
||||
NCSNpp,
|
||||
|
||||
@@ -24,7 +24,7 @@ from doc_builder.style_doc import style_docstrings_in_code
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_copies.py
|
||||
TRANSFORMERS_PATH = "src/transformers"
|
||||
TRANSFORMERS_PATH = "src/diffusers"
|
||||
PATH_TO_DOCS = "docs/source/en"
|
||||
REPO_PATH = "."
|
||||
|
||||
@@ -76,7 +76,7 @@ def _should_continue(line, indent):
|
||||
return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
|
||||
|
||||
|
||||
def find_code_in_transformers(object_name):
|
||||
def find_code_in_diffusers(object_name):
|
||||
"""Find and return the code source code of `object_name`."""
|
||||
parts = object_name.split(".")
|
||||
i = 0
|
||||
@@ -88,9 +88,7 @@ def find_code_in_transformers(object_name):
|
||||
if i < len(parts):
|
||||
module = os.path.join(module, parts[i])
|
||||
if i >= len(parts):
|
||||
raise ValueError(
|
||||
f"`object_name` should begin with the name of a module of transformers but got {object_name}."
|
||||
)
|
||||
raise ValueError(f"`object_name` should begin with the name of a module of diffusers but got {object_name}.")
|
||||
|
||||
with open(os.path.join(TRANSFORMERS_PATH, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f:
|
||||
lines = f.readlines()
|
||||
@@ -121,7 +119,7 @@ def find_code_in_transformers(object_name):
|
||||
return "".join(code_lines)
|
||||
|
||||
|
||||
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)")
|
||||
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+diffusers\.(\S+\.\S+)\s*($|\S.*$)")
|
||||
_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
|
||||
|
||||
|
||||
@@ -167,7 +165,7 @@ def is_copy_consistent(filename, overwrite=False):
|
||||
|
||||
# There is some copied code here, let's retrieve the original.
|
||||
indent, object_name, replace_pattern = search.groups()
|
||||
theoretical_code = find_code_in_transformers(object_name)
|
||||
theoretical_code = find_code_in_diffusers(object_name)
|
||||
theoretical_indent = get_indent(theoretical_code)
|
||||
|
||||
start_index = line_index + 1 if indent == theoretical_indent else line_index + 2
|
||||
@@ -235,7 +233,9 @@ def check_copies(overwrite: bool = False):
|
||||
+ diff
|
||||
+ "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them."
|
||||
)
|
||||
check_model_list_copy(overwrite=overwrite)
|
||||
|
||||
|
||||
# check_model_list_copy(overwrite=overwrite)
|
||||
|
||||
|
||||
def check_full_copies(overwrite: bool = False):
|
||||
@@ -348,8 +348,8 @@ def convert_to_localized_md(model_list, localized_model_list, format_str):
|
||||
|
||||
|
||||
def convert_readme_to_index(model_list):
|
||||
model_list = model_list.replace("https://huggingface.co/docs/transformers/main/", "")
|
||||
return model_list.replace("https://huggingface.co/docs/transformers/", "")
|
||||
model_list = model_list.replace("https://huggingface.co/docs/diffusers/main/", "")
|
||||
return model_list.replace("https://huggingface.co/docs/diffusers/", "")
|
||||
|
||||
|
||||
def _find_text_in_file(filename, start_prompt, end_prompt):
|
||||
@@ -383,9 +383,9 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
|
||||
# Fix potential doc links in the README
|
||||
with open(os.path.join(REPO_PATH, "README.md"), "r", encoding="utf-8", newline="\n") as f:
|
||||
readme = f.read()
|
||||
new_readme = readme.replace("https://huggingface.co/transformers", "https://huggingface.co/docs/transformers")
|
||||
new_readme = readme.replace("https://huggingface.co/diffusers", "https://huggingface.co/docs/diffusers")
|
||||
new_readme = new_readme.replace(
|
||||
"https://huggingface.co/docs/main/transformers", "https://huggingface.co/docs/transformers/main"
|
||||
"https://huggingface.co/docs/main/diffusers", "https://huggingface.co/docs/diffusers/main"
|
||||
)
|
||||
if new_readme != readme:
|
||||
if overwrite:
|
||||
|
||||
250
utils/custom_init_isort.py
Normal file
250
utils/custom_init_isort.py
Normal file
@@ -0,0 +1,250 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
PATH_TO_TRANSFORMERS = "src/diffusers"
|
||||
|
||||
# Pattern that looks at the indentation in a line.
|
||||
_re_indent = re.compile(r"^(\s*)\S")
|
||||
# Pattern that matches `"key":" and puts `key` in group 0.
|
||||
_re_direct_key = re.compile(r'^\s*"([^"]+)":')
|
||||
# Pattern that matches `_import_structure["key"]` and puts `key` in group 0.
|
||||
_re_indirect_key = re.compile(r'^\s*_import_structure\["([^"]+)"\]')
|
||||
# Pattern that matches `"key",` and puts `key` in group 0.
|
||||
_re_strip_line = re.compile(r'^\s*"([^"]+)",\s*$')
|
||||
# Pattern that matches any `[stuff]` and puts `stuff` in group 0.
|
||||
_re_bracket_content = re.compile(r"\[([^\]]+)\]")
|
||||
|
||||
|
||||
def get_indent(line):
|
||||
"""Returns the indent in `line`."""
|
||||
search = _re_indent.search(line)
|
||||
return "" if search is None else search.groups()[0]
|
||||
|
||||
|
||||
def split_code_in_indented_blocks(code, indent_level="", start_prompt=None, end_prompt=None):
|
||||
"""
|
||||
Split `code` into its indented blocks, starting at `indent_level`. If provided, begins splitting after
|
||||
`start_prompt` and stops at `end_prompt` (but returns what's before `start_prompt` as a first block and what's
|
||||
after `end_prompt` as a last block, so `code` is always the same as joining the result of this function).
|
||||
"""
|
||||
# Let's split the code into lines and move to start_index.
|
||||
index = 0
|
||||
lines = code.split("\n")
|
||||
if start_prompt is not None:
|
||||
while not lines[index].startswith(start_prompt):
|
||||
index += 1
|
||||
blocks = ["\n".join(lines[:index])]
|
||||
else:
|
||||
blocks = []
|
||||
|
||||
# We split into blocks until we get to the `end_prompt` (or the end of the block).
|
||||
current_block = [lines[index]]
|
||||
index += 1
|
||||
while index < len(lines) and (end_prompt is None or not lines[index].startswith(end_prompt)):
|
||||
if len(lines[index]) > 0 and get_indent(lines[index]) == indent_level:
|
||||
if len(current_block) > 0 and get_indent(current_block[-1]).startswith(indent_level + " "):
|
||||
current_block.append(lines[index])
|
||||
blocks.append("\n".join(current_block))
|
||||
if index < len(lines) - 1:
|
||||
current_block = [lines[index + 1]]
|
||||
index += 1
|
||||
else:
|
||||
current_block = []
|
||||
else:
|
||||
blocks.append("\n".join(current_block))
|
||||
current_block = [lines[index]]
|
||||
else:
|
||||
current_block.append(lines[index])
|
||||
index += 1
|
||||
|
||||
# Adds current block if it's nonempty.
|
||||
if len(current_block) > 0:
|
||||
blocks.append("\n".join(current_block))
|
||||
|
||||
# Add final block after end_prompt if provided.
|
||||
if end_prompt is not None and index < len(lines):
|
||||
blocks.append("\n".join(lines[index:]))
|
||||
|
||||
return blocks
|
||||
|
||||
|
||||
def ignore_underscore(key):
|
||||
"Wraps a `key` (that maps an object to string) to lower case and remove underscores."
|
||||
|
||||
def _inner(x):
|
||||
return key(x).lower().replace("_", "")
|
||||
|
||||
return _inner
|
||||
|
||||
|
||||
def sort_objects(objects, key=None):
|
||||
"Sort a list of `objects` following the rules of isort. `key` optionally maps an object to a str."
|
||||
# If no key is provided, we use a noop.
|
||||
def noop(x):
|
||||
return x
|
||||
|
||||
if key is None:
|
||||
key = noop
|
||||
# Constants are all uppercase, they go first.
|
||||
constants = [obj for obj in objects if key(obj).isupper()]
|
||||
# Classes are not all uppercase but start with a capital, they go second.
|
||||
classes = [obj for obj in objects if key(obj)[0].isupper() and not key(obj).isupper()]
|
||||
# Functions begin with a lowercase, they go last.
|
||||
functions = [obj for obj in objects if not key(obj)[0].isupper()]
|
||||
|
||||
key1 = ignore_underscore(key)
|
||||
return sorted(constants, key=key1) + sorted(classes, key=key1) + sorted(functions, key=key1)
|
||||
|
||||
|
||||
def sort_objects_in_import(import_statement):
|
||||
"""
|
||||
Return the same `import_statement` but with objects properly sorted.
|
||||
"""
|
||||
# This inner function sort imports between [ ].
|
||||
def _replace(match):
|
||||
imports = match.groups()[0]
|
||||
if "," not in imports:
|
||||
return f"[{imports}]"
|
||||
keys = [part.strip().replace('"', "") for part in imports.split(",")]
|
||||
# We will have a final empty element if the line finished with a comma.
|
||||
if len(keys[-1]) == 0:
|
||||
keys = keys[:-1]
|
||||
return "[" + ", ".join([f'"{k}"' for k in sort_objects(keys)]) + "]"
|
||||
|
||||
lines = import_statement.split("\n")
|
||||
if len(lines) > 3:
|
||||
# Here we have to sort internal imports that are on several lines (one per name):
|
||||
# key: [
|
||||
# "object1",
|
||||
# "object2",
|
||||
# ...
|
||||
# ]
|
||||
|
||||
# We may have to ignore one or two lines on each side.
|
||||
idx = 2 if lines[1].strip() == "[" else 1
|
||||
keys_to_sort = [(i, _re_strip_line.search(line).groups()[0]) for i, line in enumerate(lines[idx:-idx])]
|
||||
sorted_indices = sort_objects(keys_to_sort, key=lambda x: x[1])
|
||||
sorted_lines = [lines[x[0] + idx] for x in sorted_indices]
|
||||
return "\n".join(lines[:idx] + sorted_lines + lines[-idx:])
|
||||
elif len(lines) == 3:
|
||||
# Here we have to sort internal imports that are on one separate line:
|
||||
# key: [
|
||||
# "object1", "object2", ...
|
||||
# ]
|
||||
if _re_bracket_content.search(lines[1]) is not None:
|
||||
lines[1] = _re_bracket_content.sub(_replace, lines[1])
|
||||
else:
|
||||
keys = [part.strip().replace('"', "") for part in lines[1].split(",")]
|
||||
# We will have a final empty element if the line finished with a comma.
|
||||
if len(keys[-1]) == 0:
|
||||
keys = keys[:-1]
|
||||
lines[1] = get_indent(lines[1]) + ", ".join([f'"{k}"' for k in sort_objects(keys)])
|
||||
return "\n".join(lines)
|
||||
else:
|
||||
# Finally we have to deal with imports fitting on one line
|
||||
import_statement = _re_bracket_content.sub(_replace, import_statement)
|
||||
return import_statement
|
||||
|
||||
|
||||
def sort_imports(file, check_only=True):
|
||||
"""
|
||||
Sort `_import_structure` imports in `file`, `check_only` determines if we only check or overwrite.
|
||||
"""
|
||||
with open(file, "r") as f:
|
||||
code = f.read()
|
||||
|
||||
if "_import_structure" not in code:
|
||||
return
|
||||
|
||||
# Blocks of indent level 0
|
||||
main_blocks = split_code_in_indented_blocks(
|
||||
code, start_prompt="_import_structure = {", end_prompt="if TYPE_CHECKING:"
|
||||
)
|
||||
|
||||
# We ignore block 0 (everything untils start_prompt) and the last block (everything after end_prompt).
|
||||
for block_idx in range(1, len(main_blocks) - 1):
|
||||
# Check if the block contains some `_import_structure`s thingy to sort.
|
||||
block = main_blocks[block_idx]
|
||||
block_lines = block.split("\n")
|
||||
|
||||
# Get to the start of the imports.
|
||||
line_idx = 0
|
||||
while line_idx < len(block_lines) and "_import_structure" not in block_lines[line_idx]:
|
||||
# Skip dummy import blocks
|
||||
if "import dummy" in block_lines[line_idx]:
|
||||
line_idx = len(block_lines)
|
||||
else:
|
||||
line_idx += 1
|
||||
if line_idx >= len(block_lines):
|
||||
continue
|
||||
|
||||
# Ignore beginning and last line: they don't contain anything.
|
||||
internal_block_code = "\n".join(block_lines[line_idx:-1])
|
||||
indent = get_indent(block_lines[1])
|
||||
# Slit the internal block into blocks of indent level 1.
|
||||
internal_blocks = split_code_in_indented_blocks(internal_block_code, indent_level=indent)
|
||||
# We have two categories of import key: list or _import_structu[key].append/extend
|
||||
pattern = _re_direct_key if "_import_structure" in block_lines[0] else _re_indirect_key
|
||||
# Grab the keys, but there is a trap: some lines are empty or jsut comments.
|
||||
keys = [(pattern.search(b).groups()[0] if pattern.search(b) is not None else None) for b in internal_blocks]
|
||||
# We only sort the lines with a key.
|
||||
keys_to_sort = [(i, key) for i, key in enumerate(keys) if key is not None]
|
||||
sorted_indices = [x[0] for x in sorted(keys_to_sort, key=lambda x: x[1])]
|
||||
|
||||
# We reorder the blocks by leaving empty lines/comments as they were and reorder the rest.
|
||||
count = 0
|
||||
reorderded_blocks = []
|
||||
for i in range(len(internal_blocks)):
|
||||
if keys[i] is None:
|
||||
reorderded_blocks.append(internal_blocks[i])
|
||||
else:
|
||||
block = sort_objects_in_import(internal_blocks[sorted_indices[count]])
|
||||
reorderded_blocks.append(block)
|
||||
count += 1
|
||||
|
||||
# And we put our main block back together with its first and last line.
|
||||
main_blocks[block_idx] = "\n".join(block_lines[:line_idx] + reorderded_blocks + [block_lines[-1]])
|
||||
|
||||
if code != "\n".join(main_blocks):
|
||||
if check_only:
|
||||
return True
|
||||
else:
|
||||
print(f"Overwriting {file}.")
|
||||
with open(file, "w") as f:
|
||||
f.write("\n".join(main_blocks))
|
||||
|
||||
|
||||
def sort_imports_in_all_inits(check_only=True):
|
||||
failures = []
|
||||
for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
|
||||
if "__init__.py" in files:
|
||||
result = sort_imports(os.path.join(root, "__init__.py"), check_only=check_only)
|
||||
if result:
|
||||
failures = [os.path.join(root, "__init__.py")]
|
||||
if len(failures) > 0:
|
||||
raise ValueError(f"Would overwrite {len(failures)} files, run `make style`.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
|
||||
args = parser.parse_args()
|
||||
|
||||
sort_imports_in_all_inits(check_only=args.check_only)
|
||||
Reference in New Issue
Block a user